mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Matcher support for Span as well as Doc (#5113)
* Matcher support for Span, as well as Doc #5056 * Removes an import unused * Signed contributors agreement * Code optimization and better test * Add error message for bad Matcher call argument * Fix merging
This commit is contained in:
parent
1eef60c658
commit
1ca32d8f9c
2
.github/contributors/paoloq.md
vendored
2
.github/contributors/paoloq.md
vendored
|
@ -5,7 +5,7 @@ This spaCy Contributor Agreement (**"SCA"**) is based on the
|
||||||
The SCA applies to any contribution that you make to any product or project
|
The SCA applies to any contribution that you make to any product or project
|
||||||
managed by us (the **"project"**), and sets out the intellectual property rights
|
managed by us (the **"project"**), and sets out the intellectual property rights
|
||||||
you grant to us in the contributed materials. The term **"us"** shall mean
|
you grant to us in the contributed materials. The term **"us"** shall mean
|
||||||
[ExplosionAI UG (haftungsbeschränkt)](https://explosion.ai/legal). The term
|
[ExplosionAI GmbH](https://explosion.ai/legal). The term
|
||||||
**"you"** shall mean the person or entity identified below.
|
**"you"** shall mean the person or entity identified below.
|
||||||
|
|
||||||
If you agree to be bound by these terms, fill in the information requested
|
If you agree to be bound by these terms, fill in the information requested
|
||||||
|
|
|
@ -556,6 +556,7 @@ class Errors(object):
|
||||||
"({new_dim}) is not the same as the current vector dimension "
|
"({new_dim}) is not the same as the current vector dimension "
|
||||||
"({curr_dim}).")
|
"({curr_dim}).")
|
||||||
E194 = ("Unable to aligned mismatched text '{text}' and words '{words}'.")
|
E194 = ("Unable to aligned mismatched text '{text}' and words '{words}'.")
|
||||||
|
E195 = ("Matcher can be called on {good} only, got {got}.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -14,6 +14,7 @@ from ..typedefs cimport attr_t
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
from ..vocab cimport Vocab
|
from ..vocab cimport Vocab
|
||||||
from ..tokens.doc cimport Doc, get_token_attr
|
from ..tokens.doc cimport Doc, get_token_attr
|
||||||
|
from ..tokens.span cimport Span
|
||||||
from ..tokens.token cimport Token
|
from ..tokens.token cimport Token
|
||||||
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
|
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
|
||||||
|
|
||||||
|
@ -211,22 +212,29 @@ cdef class Matcher:
|
||||||
else:
|
else:
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
def __call__(self, Doc doc):
|
def __call__(self, object doc_or_span):
|
||||||
"""Find all token sequences matching the supplied pattern.
|
"""Find all token sequences matching the supplied pattern.
|
||||||
|
|
||||||
doc (Doc): The document to match over.
|
doc_or_span (Doc or Span): The document to match over.
|
||||||
RETURNS (list): A list of `(key, start, end)` tuples,
|
RETURNS (list): A list of `(key, start, end)` tuples,
|
||||||
describing the matches. A match tuple describes a span
|
describing the matches. A match tuple describes a span
|
||||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(doc_or_span, Doc):
|
||||||
|
doc = doc_or_span
|
||||||
|
length = len(doc)
|
||||||
|
elif isinstance(doc_or_span, Span):
|
||||||
|
doc = doc_or_span.doc
|
||||||
|
length = doc_or_span.end - doc_or_span.start
|
||||||
|
else:
|
||||||
|
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doc_or_span).__name__))
|
||||||
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
||||||
and not doc.is_tagged:
|
and not doc.is_tagged:
|
||||||
raise ValueError(Errors.E155.format())
|
raise ValueError(Errors.E155.format())
|
||||||
if DEP in self._seen_attrs and not doc.is_parsed:
|
if DEP in self._seen_attrs and not doc.is_parsed:
|
||||||
raise ValueError(Errors.E156.format())
|
raise ValueError(Errors.E156.format())
|
||||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
|
matches = find_matches(&self.patterns[0], self.patterns.size(), doc_or_span, length,
|
||||||
extensions=self._extensions,
|
extensions=self._extensions, predicates=self._extra_predicates)
|
||||||
predicates=self._extra_predicates)
|
|
||||||
for i, (key, start, end) in enumerate(matches):
|
for i, (key, start, end) in enumerate(matches):
|
||||||
on_match = self._callbacks.get(key, None)
|
on_match = self._callbacks.get(key, None)
|
||||||
if on_match is not None:
|
if on_match is not None:
|
||||||
|
@ -248,9 +256,7 @@ def unpickle_matcher(vocab, patterns, callbacks):
|
||||||
return matcher
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
|
cdef find_matches(TokenPatternC** patterns, int n, object doc_or_span, int length, extensions=None, predicates=tuple()):
|
||||||
cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
|
||||||
predicates=tuple()):
|
|
||||||
"""Find matches in a doc, with a compiled array of patterns. Matches are
|
"""Find matches in a doc, with a compiled array of patterns. Matches are
|
||||||
returned as a list of (id, start, end) tuples.
|
returned as a list of (id, start, end) tuples.
|
||||||
|
|
||||||
|
@ -268,18 +274,18 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
||||||
cdef int i, j, nr_extra_attr
|
cdef int i, j, nr_extra_attr
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
output = []
|
output = []
|
||||||
if doc.length == 0:
|
if length == 0:
|
||||||
# avoid any processing or mem alloc if the document is empty
|
# avoid any processing or mem alloc if the document is empty
|
||||||
return output
|
return output
|
||||||
if len(predicates) > 0:
|
if len(predicates) > 0:
|
||||||
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
|
predicate_cache = <char*>mem.alloc(length * len(predicates), sizeof(char))
|
||||||
if extensions is not None and len(extensions) >= 1:
|
if extensions is not None and len(extensions) >= 1:
|
||||||
nr_extra_attr = max(extensions.values()) + 1
|
nr_extra_attr = max(extensions.values()) + 1
|
||||||
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
|
extra_attr_values = <attr_t*>mem.alloc(length * nr_extra_attr, sizeof(attr_t))
|
||||||
else:
|
else:
|
||||||
nr_extra_attr = 0
|
nr_extra_attr = 0
|
||||||
extra_attr_values = <attr_t*>mem.alloc(doc.length, sizeof(attr_t))
|
extra_attr_values = <attr_t*>mem.alloc(length, sizeof(attr_t))
|
||||||
for i, token in enumerate(doc):
|
for i, token in enumerate(doc_or_span):
|
||||||
for name, index in extensions.items():
|
for name, index in extensions.items():
|
||||||
value = token._.get(name)
|
value = token._.get(name)
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, basestring):
|
||||||
|
@ -287,11 +293,11 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
||||||
extra_attr_values[i * nr_extra_attr + index] = value
|
extra_attr_values[i * nr_extra_attr + index] = value
|
||||||
# Main loop
|
# Main loop
|
||||||
cdef int nr_predicate = len(predicates)
|
cdef int nr_predicate = len(predicates)
|
||||||
for i in range(doc.length):
|
for i in range(length):
|
||||||
for j in range(n):
|
for j in range(n):
|
||||||
states.push_back(PatternStateC(patterns[j], i, 0))
|
states.push_back(PatternStateC(patterns[j], i, 0))
|
||||||
transition_states(states, matches, predicate_cache,
|
transition_states(states, matches, predicate_cache,
|
||||||
doc[i], extra_attr_values, predicates)
|
doc_or_span[i], extra_attr_values, predicates)
|
||||||
extra_attr_values += nr_extra_attr
|
extra_attr_values += nr_extra_attr
|
||||||
predicate_cache += len(predicates)
|
predicate_cache += len(predicates)
|
||||||
# Handle matches that end in 0-width patterns
|
# Handle matches that end in 0-width patterns
|
||||||
|
|
|
@ -6,7 +6,6 @@ import re
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
from spacy.matcher import Matcher, DependencyMatcher
|
from spacy.matcher import Matcher, DependencyMatcher
|
||||||
from spacy.tokens import Doc, Token
|
from spacy.tokens import Doc, Token
|
||||||
|
|
||||||
from ..doc.test_underscore import clean_underscore # noqa: F401
|
from ..doc.test_underscore import clean_underscore # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
@ -470,3 +469,13 @@ def test_matcher_callback(en_vocab):
|
||||||
doc = Doc(en_vocab, words=["This", "is", "a", "test", "."])
|
doc = Doc(en_vocab, words=["This", "is", "a", "test", "."])
|
||||||
matches = matcher(doc)
|
matches = matcher(doc)
|
||||||
mock.assert_called_once_with(matcher, doc, 0, matches)
|
mock.assert_called_once_with(matcher, doc, 0, matches)
|
||||||
|
|
||||||
|
|
||||||
|
def test_matcher_span(matcher):
|
||||||
|
text = "JavaScript is good but Java is better"
|
||||||
|
doc = Doc(matcher.vocab, words=text.split())
|
||||||
|
span_js = doc[:3]
|
||||||
|
span_java = doc[4:]
|
||||||
|
assert len(matcher(doc)) == 2
|
||||||
|
assert len(matcher(span_js)) == 1
|
||||||
|
assert len(matcher(span_java)) == 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user