mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Update Matcher API
This commit is contained in:
parent
8c9b3d5ad7
commit
ce9234f593
|
@ -87,7 +87,7 @@ ctypedef TokenPatternC* TokenPatternC_ptr
|
|||
ctypedef pair[int, TokenPatternC_ptr] StateC
|
||||
|
||||
|
||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
|
||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
|
||||
object token_specs) except NULL:
|
||||
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
||||
cdef int i
|
||||
|
@ -99,15 +99,21 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
|
|||
pattern[i].attrs[j].attr = attr
|
||||
pattern[i].attrs[j].value = value
|
||||
i = len(token_specs)
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(3, sizeof(AttrValueC))
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
|
||||
pattern[i].attrs[0].attr = ID
|
||||
pattern[i].attrs[0].value = entity_id
|
||||
pattern[i].attrs[1].attr = ENT_TYPE
|
||||
pattern[i].attrs[1].value = label
|
||||
pattern[i].nr_attr = 0
|
||||
return pattern
|
||||
|
||||
|
||||
cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0:
|
||||
while pattern.nr_attr != 0:
|
||||
pattern += 1
|
||||
id_attr = pattern[0].attrs[0]
|
||||
assert id_attr.attr == ID
|
||||
return id_attr.value
|
||||
|
||||
|
||||
cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
|
||||
for attr in pattern.attrs[:pattern.nr_attr]:
|
||||
if get_token_attr(token, attr.attr) != attr.value:
|
||||
|
@ -175,12 +181,11 @@ cdef class Matcher:
|
|||
cdef public object _callbacks
|
||||
cdef public object _acceptors
|
||||
|
||||
def __init__(self, vocab, patterns={}):
|
||||
def __init__(self, vocab):
|
||||
"""Create the Matcher.
|
||||
|
||||
vocab (Vocab): The vocabulary object, which must be shared with the
|
||||
documents the matcher will operate on.
|
||||
patterns (dict): Patterns to add to the matcher.
|
||||
RETURNS (Matcher): The newly constructed object.
|
||||
"""
|
||||
self._patterns = {}
|
||||
|
@ -189,123 +194,105 @@ cdef class Matcher:
|
|||
self._callbacks = {}
|
||||
self.vocab = vocab
|
||||
self.mem = Pool()
|
||||
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
|
||||
self.add_entity(entity_key, attrs)
|
||||
for spec in specs:
|
||||
self.add_pattern(entity_key, spec, label=etype)
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self.vocab, self._patterns), None, None)
|
||||
|
||||
property n_patterns:
|
||||
def __get__(self): return self.patterns.size()
|
||||
def __len__(self):
|
||||
return len(self._patterns)
|
||||
|
||||
def add_entity(self, entity_key, attrs=None, if_exists='raise',
|
||||
acceptor=None, on_match=None):
|
||||
# TODO: replace with new Matcher.add()
|
||||
"""Add an entity to the matcher.
|
||||
def __contains__(self, key):
|
||||
return len(self._patterns)
|
||||
|
||||
entity_key (unicode or int): An ID for the entity.
|
||||
attrs (dict): Attributes to associate with the `Matcher`.
|
||||
if_exists (unicode): `'raise'`, `'ignore'` or `'update'`. Controls what
|
||||
happens if the entity ID already exists. Defaults to `'raise'`.
|
||||
acceptor (function): Callback function to filter matches of the entity.
|
||||
on_match (function): Callback function to act on matches of the entity.
|
||||
def add(self, key, on_match, *patterns):
|
||||
"""Add a match-rule to the matcher.
|
||||
|
||||
A match-rule consists of: an ID key, an on_match callback,
|
||||
and one or more patterns. If the key exists, the patterns
|
||||
are appended to the previous ones, and the previous on_match
|
||||
callback is replaced.
|
||||
|
||||
The on_match callback will receive the arguments
|
||||
(matcher, doc, i, matches). Note that if no `on_match`
|
||||
callback is specified, the document will not be modified.
|
||||
|
||||
A pattern consists of one or more token_specs,
|
||||
where a token_spec is a dictionary mapping
|
||||
attribute IDs to values. Token descriptors can also
|
||||
include quantifiers. There are currently important
|
||||
known problems with the quantifiers --- see the docs.
|
||||
"""
|
||||
if if_exists not in ('raise', 'ignore', 'update'):
|
||||
raise ValueError(
|
||||
"Unexpected value for if_exists: %s.\n"
|
||||
"Expected one of: ['raise', 'ignore', 'update']" % if_exists)
|
||||
if attrs is None:
|
||||
attrs = {}
|
||||
entity_key = self.normalize_entity_key(entity_key)
|
||||
if self.has_entity(entity_key):
|
||||
if if_exists == 'raise':
|
||||
raise KeyError(
|
||||
"Tried to add entity %s. Entity exists, and if_exists='raise'.\n"
|
||||
"Set if_exists='ignore' or if_exists='update', or check with "
|
||||
"matcher.has_entity()")
|
||||
elif if_exists == 'ignore':
|
||||
return
|
||||
self._entities[entity_key] = dict(attrs)
|
||||
self._patterns.setdefault(entity_key, [])
|
||||
self._acceptors[entity_key] = acceptor
|
||||
self._callbacks[entity_key] = on_match
|
||||
|
||||
def add_pattern(self, entity_key, token_specs, label=""):
|
||||
# TODO: replace with new Matcher.add()
|
||||
"""Add a pattern to the matcher.
|
||||
|
||||
entity_key (unicode): An ID for the entity.
|
||||
token_specs (list): Description of the pattern to be matched.
|
||||
label (unicode): Label to assign to the matched pattern. Defaults to `""`.
|
||||
"""
|
||||
token_specs = list(token_specs)
|
||||
if len(token_specs) == 0:
|
||||
for pattern in patterns:
|
||||
if len(pattern) == 0:
|
||||
msg = ("Cannot add pattern for zero tokens to matcher.\n"
|
||||
"entity_key: {entity_key}\n"
|
||||
"label: {label}")
|
||||
raise ValueError(msg.format(entity_key=entity_key, label=label))
|
||||
entity_key = self.normalize_entity_key(entity_key)
|
||||
if not self.has_entity(entity_key):
|
||||
self.add_entity(entity_key)
|
||||
if isinstance(label, basestring):
|
||||
label = self.vocab.strings[label]
|
||||
elif label is None:
|
||||
label = 0
|
||||
spec = _convert_strings(token_specs, self.vocab.strings)
|
||||
"key: {key}\n")
|
||||
raise ValueError(msg.format(key=key))
|
||||
key = self._normalize_key(key)
|
||||
self._patterns.setdefault(key, [])
|
||||
self._callbacks[key] = on_match
|
||||
|
||||
self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec))
|
||||
self._patterns[entity_key].append((label, token_specs))
|
||||
for pattern in patterns:
|
||||
specs = _convert_strings(pattern, self.vocab.strings)
|
||||
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
||||
self._patterns[key].append(specs)
|
||||
|
||||
def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None):
|
||||
# TODO: replace with new Matcher.add()
|
||||
self.add_entity(entity_key, attrs=attrs, if_exists='update',
|
||||
acceptor=acceptor, on_match=on_match)
|
||||
for spec in specs:
|
||||
self.add_pattern(entity_key, spec, label=label)
|
||||
def remove(self, key):
|
||||
"""Remove a rule from the matcher.
|
||||
|
||||
def normalize_entity_key(self, entity_key):
|
||||
if isinstance(entity_key, basestring):
|
||||
return self.vocab.strings[entity_key]
|
||||
else:
|
||||
return entity_key
|
||||
|
||||
def has_entity(self, entity_key):
|
||||
# TODO: deprecate
|
||||
"""Check whether the matcher has an entity.
|
||||
|
||||
entity_key (string or int): The entity key to check.
|
||||
RETURNS (bool): Whether the matcher has the entity.
|
||||
A KeyError is raised if the key does not exist.
|
||||
"""
|
||||
entity_key = self.normalize_entity_key(entity_key)
|
||||
return entity_key in self._entities
|
||||
|
||||
def get_entity(self, entity_key):
|
||||
# TODO: deprecate
|
||||
"""Retrieve the attributes stored for an entity.
|
||||
|
||||
entity_key (unicode or int): The entity to retrieve.
|
||||
RETURNS (dict): The entity attributes if present, otherwise None.
|
||||
"""
|
||||
entity_key = self.normalize_entity_key(entity_key)
|
||||
if entity_key in self._entities:
|
||||
return self._entities[entity_key]
|
||||
key = self._normalize_key(key)
|
||||
self._patterns.pop(key)
|
||||
self._callbacks.pop(key)
|
||||
cdef int i = 0
|
||||
while i < self.patterns.size():
|
||||
pattern_key = get_pattern_key(self.patterns.at(i))
|
||||
if pattern_key == key:
|
||||
self.patterns.erase(self.patterns.begin()+i)
|
||||
else:
|
||||
return None
|
||||
i += 1
|
||||
|
||||
def __call__(self, Doc doc, acceptor=None):
|
||||
def has_key(self, key):
|
||||
"""Check whether the matcher has a rule with a given key.
|
||||
|
||||
key (string or int): The key to check.
|
||||
RETURNS (bool): Whether the matcher has the rule.
|
||||
"""
|
||||
key = self._normalize_key(key)
|
||||
return key in self._patterns
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Retrieve the pattern stored for an entity.
|
||||
|
||||
key (unicode or int): The key to retrieve.
|
||||
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
|
||||
"""
|
||||
key = self._normalize_key(key)
|
||||
if key not in self._patterns:
|
||||
return default
|
||||
return (self._callbacks[key], self._patterns[key])
|
||||
|
||||
def pipe(self, docs, batch_size=1000, n_threads=2):
|
||||
"""Match a stream of documents, yielding them in turn.
|
||||
|
||||
docs (iterable): A stream of documents.
|
||||
batch_size (int): The number of documents to accumulate into a working set.
|
||||
n_threads (int): The number of threads with which to work on the buffer
|
||||
in parallel, if the `Matcher` implementation supports multi-threading.
|
||||
YIELDS (Doc): Documents, in order.
|
||||
"""
|
||||
for doc in docs:
|
||||
self(doc)
|
||||
yield doc
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
"""Find all token sequences matching the supplied patterns on the `Doc`.
|
||||
|
||||
doc (Doc): The document to match over.
|
||||
RETURNS (list): A list of `(entity_key, label_id, start, end)` tuples,
|
||||
RETURNS (list): A list of `(key, label_id, start, end)` tuples,
|
||||
describing the matches. A match tuple describes a span
|
||||
`doc[start:end]`. The `label_id` and `entity_key` are both integers.
|
||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||
"""
|
||||
if acceptor is not None:
|
||||
raise ValueError(
|
||||
"acceptor keyword argument to Matcher deprecated. Specify acceptor "
|
||||
"functions when you add patterns instead.")
|
||||
cdef vector[StateC] partials
|
||||
cdef int n_partials = 0
|
||||
cdef int q = 0
|
||||
|
@ -343,13 +330,7 @@ cdef class Matcher:
|
|||
end = token_i+1
|
||||
ent_id = state.second[1].attrs[0].value
|
||||
label = state.second[1].attrs[1].value
|
||||
acceptor = self._acceptors.get(ent_id)
|
||||
if acceptor is None:
|
||||
matches.append((ent_id, label, start, end))
|
||||
else:
|
||||
match = acceptor(doc, ent_id, label, start, end)
|
||||
if match:
|
||||
matches.append(match)
|
||||
matches.append((ent_id, start, end))
|
||||
partials.resize(q)
|
||||
# Check whether we open any new patterns on this token
|
||||
for pattern in self.patterns:
|
||||
|
@ -374,13 +355,7 @@ cdef class Matcher:
|
|||
end = token_i+1
|
||||
ent_id = pattern[1].attrs[0].value
|
||||
label = pattern[1].attrs[1].value
|
||||
acceptor = self._acceptors.get(ent_id)
|
||||
if acceptor is None:
|
||||
matches.append((ent_id, label, start, end))
|
||||
else:
|
||||
match = acceptor(doc, ent_id, label, start, end)
|
||||
if match:
|
||||
matches.append(match)
|
||||
matches.append((ent_id, start, end))
|
||||
# Look for open patterns that are actually satisfied
|
||||
for state in partials:
|
||||
while state.second.quantifier in (ZERO, ZERO_PLUS):
|
||||
|
@ -390,13 +365,7 @@ cdef class Matcher:
|
|||
end = len(doc)
|
||||
ent_id = state.second.attrs[0].value
|
||||
label = state.second.attrs[0].value
|
||||
acceptor = self._acceptors.get(ent_id)
|
||||
if acceptor is None:
|
||||
matches.append((ent_id, label, start, end))
|
||||
else:
|
||||
match = acceptor(doc, ent_id, label, start, end)
|
||||
if match:
|
||||
matches.append(match)
|
||||
matches.append((ent_id, start, end))
|
||||
for i, (ent_id, label, start, end) in enumerate(matches):
|
||||
on_match = self._callbacks.get(ent_id)
|
||||
if on_match is not None:
|
||||
|
@ -404,18 +373,11 @@ cdef class Matcher:
|
|||
# TODO: only return (match_id, start, end)
|
||||
return matches
|
||||
|
||||
def pipe(self, docs, batch_size=1000, n_threads=2):
|
||||
"""Match a stream of documents, yielding them in turn.
|
||||
|
||||
docs (iterable): A stream of documents.
|
||||
batch_size (int): The number of documents to accumulate into a working set.
|
||||
n_threads (int): The number of threads with which to work on the buffer
|
||||
in parallel, if the `Matcher` implementation supports multi-threading.
|
||||
YIELDS (Doc): Documents, in order.
|
||||
"""
|
||||
for doc in docs:
|
||||
self(doc)
|
||||
yield doc
|
||||
def _normalize_key(self, key):
|
||||
if isinstance(key, basestring):
|
||||
return self.vocab.strings[key]
|
||||
else:
|
||||
return key
|
||||
|
||||
|
||||
def get_bilou(length):
|
||||
|
|
|
@ -116,8 +116,9 @@ p Match a stream of documents, yielding them in turn.
|
|||
+tag method
|
||||
|
||||
p
|
||||
| Add one or more patterns to the matcher, along with a callback function
|
||||
| to handle the matches. The callback function will receive the arguments
|
||||
| Add a rule to the matcher, consisting of an ID key, one or more patterns, and
|
||||
| a callback function to act on the matches.
|
||||
| The callback function will receive the arguments
|
||||
| #[code matcher], #[code doc], #[code i] and #[code matches].
|
||||
|
||||
+aside-code("Example").
|
||||
|
|
Loading…
Reference in New Issue
Block a user