mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Update Matcher API
This commit is contained in:
parent
8c9b3d5ad7
commit
ce9234f593
|
@ -87,7 +87,7 @@ ctypedef TokenPatternC* TokenPatternC_ptr
|
||||||
ctypedef pair[int, TokenPatternC_ptr] StateC
|
ctypedef pair[int, TokenPatternC_ptr] StateC
|
||||||
|
|
||||||
|
|
||||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
|
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
|
||||||
object token_specs) except NULL:
|
object token_specs) except NULL:
|
||||||
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
||||||
cdef int i
|
cdef int i
|
||||||
|
@ -99,15 +99,21 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
|
||||||
pattern[i].attrs[j].attr = attr
|
pattern[i].attrs[j].attr = attr
|
||||||
pattern[i].attrs[j].value = value
|
pattern[i].attrs[j].value = value
|
||||||
i = len(token_specs)
|
i = len(token_specs)
|
||||||
pattern[i].attrs = <AttrValueC*>mem.alloc(3, sizeof(AttrValueC))
|
pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
|
||||||
pattern[i].attrs[0].attr = ID
|
pattern[i].attrs[0].attr = ID
|
||||||
pattern[i].attrs[0].value = entity_id
|
pattern[i].attrs[0].value = entity_id
|
||||||
pattern[i].attrs[1].attr = ENT_TYPE
|
|
||||||
pattern[i].attrs[1].value = label
|
|
||||||
pattern[i].nr_attr = 0
|
pattern[i].nr_attr = 0
|
||||||
return pattern
|
return pattern
|
||||||
|
|
||||||
|
|
||||||
|
cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0:
|
||||||
|
while pattern.nr_attr != 0:
|
||||||
|
pattern += 1
|
||||||
|
id_attr = pattern[0].attrs[0]
|
||||||
|
assert id_attr.attr == ID
|
||||||
|
return id_attr.value
|
||||||
|
|
||||||
|
|
||||||
cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
|
cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
|
||||||
for attr in pattern.attrs[:pattern.nr_attr]:
|
for attr in pattern.attrs[:pattern.nr_attr]:
|
||||||
if get_token_attr(token, attr.attr) != attr.value:
|
if get_token_attr(token, attr.attr) != attr.value:
|
||||||
|
@ -175,12 +181,11 @@ cdef class Matcher:
|
||||||
cdef public object _callbacks
|
cdef public object _callbacks
|
||||||
cdef public object _acceptors
|
cdef public object _acceptors
|
||||||
|
|
||||||
def __init__(self, vocab, patterns={}):
|
def __init__(self, vocab):
|
||||||
"""Create the Matcher.
|
"""Create the Matcher.
|
||||||
|
|
||||||
vocab (Vocab): The vocabulary object, which must be shared with the
|
vocab (Vocab): The vocabulary object, which must be shared with the
|
||||||
documents the matcher will operate on.
|
documents the matcher will operate on.
|
||||||
patterns (dict): Patterns to add to the matcher.
|
|
||||||
RETURNS (Matcher): The newly constructed object.
|
RETURNS (Matcher): The newly constructed object.
|
||||||
"""
|
"""
|
||||||
self._patterns = {}
|
self._patterns = {}
|
||||||
|
@ -189,123 +194,105 @@ cdef class Matcher:
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
|
|
||||||
self.add_entity(entity_key, attrs)
|
|
||||||
for spec in specs:
|
|
||||||
self.add_pattern(entity_key, spec, label=etype)
|
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (self.__class__, (self.vocab, self._patterns), None, None)
|
return (self.__class__, (self.vocab, self._patterns), None, None)
|
||||||
|
|
||||||
property n_patterns:
|
def __len__(self):
|
||||||
def __get__(self): return self.patterns.size()
|
return len(self._patterns)
|
||||||
|
|
||||||
def add_entity(self, entity_key, attrs=None, if_exists='raise',
|
def __contains__(self, key):
|
||||||
acceptor=None, on_match=None):
|
return len(self._patterns)
|
||||||
# TODO: replace with new Matcher.add()
|
|
||||||
"""Add an entity to the matcher.
|
|
||||||
|
|
||||||
entity_key (unicode or int): An ID for the entity.
|
def add(self, key, on_match, *patterns):
|
||||||
attrs (dict): Attributes to associate with the `Matcher`.
|
"""Add a match-rule to the matcher.
|
||||||
if_exists (unicode): `'raise'`, `'ignore'` or `'update'`. Controls what
|
|
||||||
happens if the entity ID already exists. Defaults to `'raise'`.
|
A match-rule consists of: an ID key, an on_match callback,
|
||||||
acceptor (function): Callback function to filter matches of the entity.
|
and one or more patterns. If the key exists, the patterns
|
||||||
on_match (function): Callback function to act on matches of the entity.
|
are appended to the previous ones, and the previous on_match
|
||||||
|
callback is replaced.
|
||||||
|
|
||||||
|
The on_match callback will receive the arguments
|
||||||
|
(matcher, doc, i, matches). Note that if no `on_match`
|
||||||
|
callback is specified, the document will not be modified.
|
||||||
|
|
||||||
|
A pattern consists of one or more token_specs,
|
||||||
|
where a token_spec is a dictionary mapping
|
||||||
|
attribute IDs to values. Token descriptors can also
|
||||||
|
include quantifiers. There are currently important
|
||||||
|
known problems with the quantifiers --- see the docs.
|
||||||
"""
|
"""
|
||||||
if if_exists not in ('raise', 'ignore', 'update'):
|
for pattern in patterns:
|
||||||
raise ValueError(
|
if len(pattern) == 0:
|
||||||
"Unexpected value for if_exists: %s.\n"
|
msg = ("Cannot add pattern for zero tokens to matcher.\n"
|
||||||
"Expected one of: ['raise', 'ignore', 'update']" % if_exists)
|
"key: {key}\n")
|
||||||
if attrs is None:
|
raise ValueError(msg.format(key=key))
|
||||||
attrs = {}
|
key = self._normalize_key(key)
|
||||||
entity_key = self.normalize_entity_key(entity_key)
|
self._patterns.setdefault(key, [])
|
||||||
if self.has_entity(entity_key):
|
self._callbacks[key] = on_match
|
||||||
if if_exists == 'raise':
|
|
||||||
raise KeyError(
|
|
||||||
"Tried to add entity %s. Entity exists, and if_exists='raise'.\n"
|
|
||||||
"Set if_exists='ignore' or if_exists='update', or check with "
|
|
||||||
"matcher.has_entity()")
|
|
||||||
elif if_exists == 'ignore':
|
|
||||||
return
|
|
||||||
self._entities[entity_key] = dict(attrs)
|
|
||||||
self._patterns.setdefault(entity_key, [])
|
|
||||||
self._acceptors[entity_key] = acceptor
|
|
||||||
self._callbacks[entity_key] = on_match
|
|
||||||
|
|
||||||
def add_pattern(self, entity_key, token_specs, label=""):
|
for pattern in patterns:
|
||||||
# TODO: replace with new Matcher.add()
|
specs = _convert_strings(pattern, self.vocab.strings)
|
||||||
"""Add a pattern to the matcher.
|
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
||||||
|
self._patterns[key].append(specs)
|
||||||
|
|
||||||
entity_key (unicode): An ID for the entity.
|
def remove(self, key):
|
||||||
token_specs (list): Description of the pattern to be matched.
|
"""Remove a rule from the matcher.
|
||||||
label (unicode): Label to assign to the matched pattern. Defaults to `""`.
|
|
||||||
|
A KeyError is raised if the key does not exist.
|
||||||
"""
|
"""
|
||||||
token_specs = list(token_specs)
|
key = self._normalize_key(key)
|
||||||
if len(token_specs) == 0:
|
self._patterns.pop(key)
|
||||||
msg = ("Cannot add pattern for zero tokens to matcher.\n"
|
self._callbacks.pop(key)
|
||||||
"entity_key: {entity_key}\n"
|
cdef int i = 0
|
||||||
"label: {label}")
|
while i < self.patterns.size():
|
||||||
raise ValueError(msg.format(entity_key=entity_key, label=label))
|
pattern_key = get_pattern_key(self.patterns.at(i))
|
||||||
entity_key = self.normalize_entity_key(entity_key)
|
if pattern_key == key:
|
||||||
if not self.has_entity(entity_key):
|
self.patterns.erase(self.patterns.begin()+i)
|
||||||
self.add_entity(entity_key)
|
else:
|
||||||
if isinstance(label, basestring):
|
i += 1
|
||||||
label = self.vocab.strings[label]
|
|
||||||
elif label is None:
|
|
||||||
label = 0
|
|
||||||
spec = _convert_strings(token_specs, self.vocab.strings)
|
|
||||||
|
|
||||||
self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec))
|
def has_key(self, key):
|
||||||
self._patterns[entity_key].append((label, token_specs))
|
"""Check whether the matcher has a rule with a given key.
|
||||||
|
|
||||||
def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None):
|
key (string or int): The key to check.
|
||||||
# TODO: replace with new Matcher.add()
|
RETURNS (bool): Whether the matcher has the rule.
|
||||||
self.add_entity(entity_key, attrs=attrs, if_exists='update',
|
|
||||||
acceptor=acceptor, on_match=on_match)
|
|
||||||
for spec in specs:
|
|
||||||
self.add_pattern(entity_key, spec, label=label)
|
|
||||||
|
|
||||||
def normalize_entity_key(self, entity_key):
|
|
||||||
if isinstance(entity_key, basestring):
|
|
||||||
return self.vocab.strings[entity_key]
|
|
||||||
else:
|
|
||||||
return entity_key
|
|
||||||
|
|
||||||
def has_entity(self, entity_key):
|
|
||||||
# TODO: deprecate
|
|
||||||
"""Check whether the matcher has an entity.
|
|
||||||
|
|
||||||
entity_key (string or int): The entity key to check.
|
|
||||||
RETURNS (bool): Whether the matcher has the entity.
|
|
||||||
"""
|
"""
|
||||||
entity_key = self.normalize_entity_key(entity_key)
|
key = self._normalize_key(key)
|
||||||
return entity_key in self._entities
|
return key in self._patterns
|
||||||
|
|
||||||
def get_entity(self, entity_key):
|
def get(self, key, default=None):
|
||||||
# TODO: deprecate
|
"""Retrieve the pattern stored for an entity.
|
||||||
"""Retrieve the attributes stored for an entity.
|
|
||||||
|
|
||||||
entity_key (unicode or int): The entity to retrieve.
|
key (unicode or int): The key to retrieve.
|
||||||
RETURNS (dict): The entity attributes if present, otherwise None.
|
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
|
||||||
"""
|
"""
|
||||||
entity_key = self.normalize_entity_key(entity_key)
|
key = self._normalize_key(key)
|
||||||
if entity_key in self._entities:
|
if key not in self._patterns:
|
||||||
return self._entities[entity_key]
|
return default
|
||||||
else:
|
return (self._callbacks[key], self._patterns[key])
|
||||||
return None
|
|
||||||
|
|
||||||
def __call__(self, Doc doc, acceptor=None):
|
def pipe(self, docs, batch_size=1000, n_threads=2):
|
||||||
|
"""Match a stream of documents, yielding them in turn.
|
||||||
|
|
||||||
|
docs (iterable): A stream of documents.
|
||||||
|
batch_size (int): The number of documents to accumulate into a working set.
|
||||||
|
n_threads (int): The number of threads with which to work on the buffer
|
||||||
|
in parallel, if the `Matcher` implementation supports multi-threading.
|
||||||
|
YIELDS (Doc): Documents, in order.
|
||||||
|
"""
|
||||||
|
for doc in docs:
|
||||||
|
self(doc)
|
||||||
|
yield doc
|
||||||
|
|
||||||
|
def __call__(self, Doc doc):
|
||||||
"""Find all token sequences matching the supplied patterns on the `Doc`.
|
"""Find all token sequences matching the supplied patterns on the `Doc`.
|
||||||
|
|
||||||
doc (Doc): The document to match over.
|
doc (Doc): The document to match over.
|
||||||
RETURNS (list): A list of `(entity_key, label_id, start, end)` tuples,
|
RETURNS (list): A list of `(key, label_id, start, end)` tuples,
|
||||||
describing the matches. A match tuple describes a span
|
describing the matches. A match tuple describes a span
|
||||||
`doc[start:end]`. The `label_id` and `entity_key` are both integers.
|
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||||
"""
|
"""
|
||||||
if acceptor is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"acceptor keyword argument to Matcher deprecated. Specify acceptor "
|
|
||||||
"functions when you add patterns instead.")
|
|
||||||
cdef vector[StateC] partials
|
cdef vector[StateC] partials
|
||||||
cdef int n_partials = 0
|
cdef int n_partials = 0
|
||||||
cdef int q = 0
|
cdef int q = 0
|
||||||
|
@ -343,13 +330,7 @@ cdef class Matcher:
|
||||||
end = token_i+1
|
end = token_i+1
|
||||||
ent_id = state.second[1].attrs[0].value
|
ent_id = state.second[1].attrs[0].value
|
||||||
label = state.second[1].attrs[1].value
|
label = state.second[1].attrs[1].value
|
||||||
acceptor = self._acceptors.get(ent_id)
|
matches.append((ent_id, start, end))
|
||||||
if acceptor is None:
|
|
||||||
matches.append((ent_id, label, start, end))
|
|
||||||
else:
|
|
||||||
match = acceptor(doc, ent_id, label, start, end)
|
|
||||||
if match:
|
|
||||||
matches.append(match)
|
|
||||||
partials.resize(q)
|
partials.resize(q)
|
||||||
# Check whether we open any new patterns on this token
|
# Check whether we open any new patterns on this token
|
||||||
for pattern in self.patterns:
|
for pattern in self.patterns:
|
||||||
|
@ -374,13 +355,7 @@ cdef class Matcher:
|
||||||
end = token_i+1
|
end = token_i+1
|
||||||
ent_id = pattern[1].attrs[0].value
|
ent_id = pattern[1].attrs[0].value
|
||||||
label = pattern[1].attrs[1].value
|
label = pattern[1].attrs[1].value
|
||||||
acceptor = self._acceptors.get(ent_id)
|
matches.append((ent_id, start, end))
|
||||||
if acceptor is None:
|
|
||||||
matches.append((ent_id, label, start, end))
|
|
||||||
else:
|
|
||||||
match = acceptor(doc, ent_id, label, start, end)
|
|
||||||
if match:
|
|
||||||
matches.append(match)
|
|
||||||
# Look for open patterns that are actually satisfied
|
# Look for open patterns that are actually satisfied
|
||||||
for state in partials:
|
for state in partials:
|
||||||
while state.second.quantifier in (ZERO, ZERO_PLUS):
|
while state.second.quantifier in (ZERO, ZERO_PLUS):
|
||||||
|
@ -390,13 +365,7 @@ cdef class Matcher:
|
||||||
end = len(doc)
|
end = len(doc)
|
||||||
ent_id = state.second.attrs[0].value
|
ent_id = state.second.attrs[0].value
|
||||||
label = state.second.attrs[0].value
|
label = state.second.attrs[0].value
|
||||||
acceptor = self._acceptors.get(ent_id)
|
matches.append((ent_id, start, end))
|
||||||
if acceptor is None:
|
|
||||||
matches.append((ent_id, label, start, end))
|
|
||||||
else:
|
|
||||||
match = acceptor(doc, ent_id, label, start, end)
|
|
||||||
if match:
|
|
||||||
matches.append(match)
|
|
||||||
for i, (ent_id, label, start, end) in enumerate(matches):
|
for i, (ent_id, label, start, end) in enumerate(matches):
|
||||||
on_match = self._callbacks.get(ent_id)
|
on_match = self._callbacks.get(ent_id)
|
||||||
if on_match is not None:
|
if on_match is not None:
|
||||||
|
@ -404,18 +373,11 @@ cdef class Matcher:
|
||||||
# TODO: only return (match_id, start, end)
|
# TODO: only return (match_id, start, end)
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
def pipe(self, docs, batch_size=1000, n_threads=2):
|
def _normalize_key(self, key):
|
||||||
"""Match a stream of documents, yielding them in turn.
|
if isinstance(key, basestring):
|
||||||
|
return self.vocab.strings[key]
|
||||||
docs (iterable): A stream of documents.
|
else:
|
||||||
batch_size (int): The number of documents to accumulate into a working set.
|
return key
|
||||||
n_threads (int): The number of threads with which to work on the buffer
|
|
||||||
in parallel, if the `Matcher` implementation supports multi-threading.
|
|
||||||
YIELDS (Doc): Documents, in order.
|
|
||||||
"""
|
|
||||||
for doc in docs:
|
|
||||||
self(doc)
|
|
||||||
yield doc
|
|
||||||
|
|
||||||
|
|
||||||
def get_bilou(length):
|
def get_bilou(length):
|
||||||
|
|
|
@ -116,8 +116,9 @@ p Match a stream of documents, yielding them in turn.
|
||||||
+tag method
|
+tag method
|
||||||
|
|
||||||
p
|
p
|
||||||
| Add one or more patterns to the matcher, along with a callback function
|
| Add a rule to the matcher, consisting of an ID key, one or more patterns, and
|
||||||
| to handle the matches. The callback function will receive the arguments
|
| a callback function to act on the matches.
|
||||||
|
| The callback function will receive the arguments
|
||||||
| #[code matcher], #[code doc], #[code i] and #[code matches].
|
| #[code matcher], #[code doc], #[code i] and #[code matches].
|
||||||
|
|
||||||
+aside-code("Example").
|
+aside-code("Example").
|
||||||
|
|
Loading…
Reference in New Issue
Block a user