2015-10-08 18:00:45 +03:00
|
|
|
|
# cython: profile=True
|
2016-10-28 18:42:00 +03:00
|
|
|
|
# cython: infer_types=True
|
2017-04-15 13:05:47 +03:00
|
|
|
|
# coding: utf8
|
2015-10-08 18:00:45 +03:00
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
2017-04-15 13:13:34 +03:00
|
|
|
|
import ujson
|
|
|
|
|
|
2015-08-05 02:05:54 +03:00
|
|
|
|
from .typedefs cimport attr_t
|
2015-10-08 18:00:45 +03:00
|
|
|
|
from .typedefs cimport hash_t
|
2015-08-05 02:05:54 +03:00
|
|
|
|
from .attrs cimport attr_id_t
|
2017-03-21 23:08:54 +03:00
|
|
|
|
from .structs cimport TokenC
|
2015-08-04 16:55:28 +03:00
|
|
|
|
|
2015-08-05 02:05:54 +03:00
|
|
|
|
from cymem.cymem cimport Pool
|
2015-10-08 18:00:45 +03:00
|
|
|
|
from preshed.maps cimport PreshMap
|
2015-08-05 02:05:54 +03:00
|
|
|
|
from libcpp.vector cimport vector
|
2016-09-21 15:54:55 +03:00
|
|
|
|
from libcpp.pair cimport pair
|
2015-10-08 18:00:45 +03:00
|
|
|
|
from murmurhash.mrmr cimport hash64
|
2016-09-21 15:54:55 +03:00
|
|
|
|
from libc.stdint cimport int32_t
|
2015-08-04 16:55:28 +03:00
|
|
|
|
|
2017-03-21 23:08:54 +03:00
|
|
|
|
from .attrs cimport ID, ENT_TYPE
|
2016-04-14 11:37:39 +03:00
|
|
|
|
from . import attrs
|
2015-08-05 02:05:54 +03:00
|
|
|
|
from .tokens.doc cimport get_token_attr
|
|
|
|
|
from .tokens.doc cimport Doc
|
|
|
|
|
from .vocab cimport Vocab
|
2015-08-04 16:55:28 +03:00
|
|
|
|
|
2015-10-08 18:00:45 +03:00
|
|
|
|
from .attrs import FLAG61 as U_ENT
|
|
|
|
|
|
|
|
|
|
from .attrs import FLAG60 as B2_ENT
|
|
|
|
|
from .attrs import FLAG59 as B3_ENT
|
|
|
|
|
from .attrs import FLAG58 as B4_ENT
|
|
|
|
|
from .attrs import FLAG57 as B5_ENT
|
|
|
|
|
from .attrs import FLAG56 as B6_ENT
|
|
|
|
|
from .attrs import FLAG55 as B7_ENT
|
|
|
|
|
from .attrs import FLAG54 as B8_ENT
|
|
|
|
|
from .attrs import FLAG53 as B9_ENT
|
|
|
|
|
from .attrs import FLAG52 as B10_ENT
|
|
|
|
|
|
|
|
|
|
from .attrs import FLAG51 as I3_ENT
|
|
|
|
|
from .attrs import FLAG50 as I4_ENT
|
|
|
|
|
from .attrs import FLAG49 as I5_ENT
|
|
|
|
|
from .attrs import FLAG48 as I6_ENT
|
|
|
|
|
from .attrs import FLAG47 as I7_ENT
|
|
|
|
|
from .attrs import FLAG46 as I8_ENT
|
|
|
|
|
from .attrs import FLAG45 as I9_ENT
|
|
|
|
|
from .attrs import FLAG44 as I10_ENT
|
|
|
|
|
|
|
|
|
|
from .attrs import FLAG43 as L2_ENT
|
|
|
|
|
from .attrs import FLAG42 as L3_ENT
|
|
|
|
|
from .attrs import FLAG41 as L4_ENT
|
|
|
|
|
from .attrs import FLAG40 as L5_ENT
|
|
|
|
|
from .attrs import FLAG39 as L6_ENT
|
|
|
|
|
from .attrs import FLAG38 as L7_ENT
|
|
|
|
|
from .attrs import FLAG37 as L8_ENT
|
|
|
|
|
from .attrs import FLAG36 as L9_ENT
|
|
|
|
|
from .attrs import FLAG35 as L10_ENT
|
|
|
|
|
|
|
|
|
|
|
2016-09-21 15:54:55 +03:00
|
|
|
|
# Quantifier attached to one token of a pattern. `_convert_strings` maps the
# 'OP' strings onto these members: '!' -> ZERO, '1' -> ONE, '?' -> ZERO_ONE,
# '*' -> ZERO_PLUS, and '+' -> the pair (ONE, ZERO_PLUS).
cpdef enum quantifier_t:
    # Not produced by any operator string in this file.
    # NOTE(review): presumably the value left on the zero-initialised terminal
    # entry of a pattern -- confirm against the allocator's behaviour.
    _META
    ONE
    ZERO
    ZERO_ONE
    ZERO_PLUS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Outcome of testing one token against one pattern token (see `get_action`).
cdef enum action_t:
    REJECT        # the candidate match is dropped
    ADVANCE       # token consumed; continue with the next pattern token
    REPEAT        # token consumed; stay on the same pattern token
    ACCEPT        # token consumed and the pattern is complete
    ADVANCE_ZERO  # pattern token skipped; retry the same token on the next one
    PANIC         # quantifier value not handled by `get_action`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One (attribute ID, required value) constraint on a token.
cdef struct AttrValueC:
    attr_id_t attr   # which token attribute to read
    attr_t value     # value the attribute must equal
|
2015-08-04 16:55:28 +03:00
|
|
|
|
|
|
|
|
|
|
2016-09-21 15:54:55 +03:00
|
|
|
|
# One token slot of a compiled pattern. A pattern is an array of these,
# closed by a terminal entry whose `nr_attr` is 0 (see `init_pattern`).
cdef struct TokenPatternC:
    AttrValueC* attrs        # constraints the token must satisfy
    int32_t nr_attr          # number of entries in `attrs`; 0 marks the terminal
    quantifier_t quantifier  # how often this slot may match


# Pointer alias so the type can be used as a C++ template argument.
ctypedef TokenPatternC* TokenPatternC_ptr

# A partial match: (index of the first matched token, current pattern slot).
ctypedef pair[int, TokenPatternC_ptr] StateC
|
|
|
|
|
|
|
|
|
|
|
2017-05-20 14:54:53 +03:00
|
|
|
|
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
                                 object token_specs) except NULL:
    """Compile converted token specs into a C array of pattern slots.

    mem (Pool): Memory pool that owns every allocation made here.
    entity_id (attr_t): Key stored on the terminal entry of the pattern.
    token_specs: Sequence of `(quantifier, [(attr, value), ...])` pairs.
    RETURNS (TokenPatternC*): Array of `len(token_specs) + 1` entries; the
        last one is the terminal, with `nr_attr == 0`.
    """
    cdef int n_tokens = len(token_specs)
    cdef TokenPatternC* pattern = <TokenPatternC*>mem.alloc(
        n_tokens + 1, sizeof(TokenPatternC))
    cdef TokenPatternC* slot
    cdef int i
    cdef int j
    for i, (quantifier, spec) in enumerate(token_specs):
        slot = &pattern[i]
        slot.quantifier = quantifier
        slot.attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
        # NOTE(review): a spec with no attributes also yields nr_attr == 0,
        # the same value that marks the terminal entry below -- confirm that
        # callers never pass an empty spec.
        slot.nr_attr = len(spec)
        for j, (attr, value) in enumerate(spec):
            slot.attrs[j].attr = attr
            slot.attrs[j].value = value
    # Terminal entry: two attribute cells are allocated, only the first one
    # is written (it carries the pattern's key under the ID attribute).
    slot = &pattern[n_tokens]
    slot.attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
    slot.attrs[0].attr = ID
    slot.attrs[0].value = entity_id
    slot.nr_attr = 0
    return pattern
|
|
|
|
|
|
|
|
|
|
|
2017-05-20 14:54:53 +03:00
|
|
|
|
cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0:
    """Return the key stored on the terminal entry of a compiled pattern.

    pattern (const TokenPatternC*): Any slot of a pattern built by
        `init_pattern`.
    RETURNS (attr_t): Value of the terminal entry's first attribute.
    """
    cdef AttrValueC key_cell
    # Walk forward until the terminal entry, which is the one with no attrs.
    while pattern.nr_attr != 0:
        pattern += 1
    key_cell = pattern[0].attrs[0]
    assert key_cell.attr == ID
    return key_cell.value
|
|
|
|
|
|
|
|
|
|
|
2016-09-21 15:54:55 +03:00
|
|
|
|
cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
    # Decide what the matcher should do with `token` at pattern slot `pattern`.
    # Returns one of the `action_t` values. The slot after `pattern` is only
    # inspected in the branches that need to know whether it is the terminal
    # entry (nr_attr == 0).
    cdef bint all_match = True
    cdef int k
    for k in range(pattern.nr_attr):
        if get_token_attr(token, pattern.attrs[k].attr) != pattern.attrs[k].value:
            all_match = False
            break

    if all_match:
        if pattern.quantifier == ZERO:
            # The token is exactly what the pattern forbids.
            return REJECT
        if pattern.quantifier == ONE or pattern.quantifier == ZERO_ONE:
            return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE
        if pattern.quantifier == ZERO_PLUS:
            return REPEAT
        return PANIC

    # At least one attribute differs from what the slot requires.
    if pattern.quantifier == ONE:
        return REJECT
    if pattern.quantifier == ZERO:
        return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE
    if pattern.quantifier == ZERO_ONE or pattern.quantifier == ZERO_PLUS:
        return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE_ZERO
    return PANIC
|
2015-08-05 02:05:54 +03:00
|
|
|
|
|
|
|
|
|
|
2015-08-06 15:33:21 +03:00
|
|
|
|
def _convert_strings(token_specs, string_store):
    """Convert user-facing token specs into `(quantifier, attrs)` pairs.

    token_specs: Iterable of dicts mapping attribute names/IDs to values.
        The key 'OP' (any case) selects a quantifier.
    string_store: Mapping used to turn string values into IDs.
    RETURNS (list): `(quantifier, [(attr_id, value), ...])` tuples; a spec
        with OP '+' contributes two entries sharing the same attribute list.
    RAISES (KeyError): If an 'OP' value is not a known operator.
    """
    # Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
    operators = {'!': (ZERO,), '*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
                 '?': (ZERO_ONE,), '1': (ONE,)}
    converted = []
    for spec in token_specs:
        constraints = []
        quantifiers = (ONE,)
        for attr, value in spec.items():
            named = isinstance(attr, basestring)
            if named and attr.upper() == 'OP':
                if value not in operators:
                    raise KeyError(
                        "Unknown operator '%s'. Options: %s" % (value, ', '.join(operators.keys())))
                quantifiers = operators[value]
            if named:
                # Names without an ID resolve to None and are skipped below.
                attr = attrs.IDS.get(attr.upper())
            if isinstance(value, basestring):
                value = string_store[value]
            if isinstance(value, bool):
                value = int(value)
            if attr is not None:
                constraints.append((attr, value))
        for quantifier in quantifiers:
            converted.append((quantifier, constraints))
    return converted
|
2015-10-08 18:00:45 +03:00
|
|
|
|
|
|
|
|
|
|
2017-03-31 14:58:59 +03:00
|
|
|
|
def merge_phrase(matcher, doc, i, matches):
    """Callback to merge a phrase on match.

    matcher: The matcher that produced the match (unused).
    doc: The document the match was found in.
    i (int): Index of the current match in `matches`.
    matches (list): Match tuples, either `(ent_id, start, end)` as returned
        by `Matcher.__call__`, or `(ent_id, label, start, end)`.
    """
    match = matches[i]
    if len(match) == 4:
        ent_id, label, start, end = match
        doc[start:end].merge(ent_type=label, ent_id=ent_id)
    else:
        # The matcher in this module yields 3-tuples without a label, so
        # there is no entity type to assign.
        ent_id, start, end = match
        doc[start:end].merge(ent_id=ent_id)
|
|
|
|
|
|
|
|
|
|
|
2015-08-05 02:05:54 +03:00
|
|
|
|
cdef class Matcher:
    """Match sequences of tokens, based on pattern rules."""
    cdef Pool mem
    cdef vector[TokenPatternC*] patterns
    cdef readonly Vocab vocab
    cdef public object _patterns
    cdef public object _entities
    cdef public object _callbacks
    cdef public object _acceptors

    def __init__(self, vocab):
        """Create the Matcher.

        vocab (Vocab): The vocabulary object, which must be shared with the
            documents the matcher will operate on.
        RETURNS (Matcher): The newly constructed object.
        """
        self._patterns = {}
        self._entities = {}
        self._acceptors = {}
        self._callbacks = {}
        self.vocab = vocab
        self.mem = Pool()

    def __reduce__(self):
        """Support pickling.

        The constructor only accepts the vocab, so the rules travel as
        pickle state and are rebuilt by `__setstate__`.
        """
        state = (self._patterns, self._callbacks)
        return (self.__class__, (self.vocab,), state)

    def __setstate__(self, state):
        """Restore rules from the state produced by `__reduce__`.

        state (tuple): `(patterns, callbacks)`, where `patterns` maps each
            key to its list of converted token specs.
        """
        patterns, callbacks = state
        for key, spec_lists in patterns.items():
            self._patterns.setdefault(key, [])
            for specs in spec_lists:
                # The specs are already converted, so compile them directly
                # instead of going through `add`.
                self.patterns.push_back(init_pattern(self.mem, key, specs))
                self._patterns[key].append(specs)
        self._callbacks.update(callbacks)

    def __len__(self):
        """Get the number of rules added to the matcher. Note that this only
        returns the number of rules (identical with the number of IDs), not the
        number of individual patterns.

        RETURNS (int): The number of rules.
        """
        return len(self._patterns)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        return self._normalize_key(key) in self._patterns

    def add(self, key, on_match, *patterns):
        """Add a match-rule to the matcher.

        A match-rule consists of: an ID key, an on_match callback, and one or
        more patterns. If the key exists, the patterns are appended to the
        previous ones, and the previous on_match callback is replaced. The
        `on_match` callback will receive the arguments `(matcher, doc, i,
        matches)`. You can also set `on_match` to `None` to not perform any
        actions. A pattern consists of one or more `token_specs`, where a
        `token_spec` is a dictionary mapping attribute IDs to values. Token
        descriptors can also include quantifiers. There are currently important
        known problems with the quantifiers – see the docs.

        key (unicode or int): The match ID.
        on_match (callable or None): Callback invoked for each match.
        *patterns (list): One or more lists of token specs.
        RAISES (ValueError): If any pattern contains no tokens.
        """
        # Validate everything first so a bad pattern leaves no partial rule.
        for pattern in patterns:
            if len(pattern) == 0:
                msg = ("Cannot add pattern for zero tokens to matcher.\n"
                       "key: {key}\n")
                raise ValueError(msg.format(key=key))
        key = self._normalize_key(key)
        self._patterns.setdefault(key, [])
        self._callbacks[key] = on_match

        for pattern in patterns:
            specs = _convert_strings(pattern, self.vocab.strings)
            self.patterns.push_back(init_pattern(self.mem, key, specs))
            self._patterns[key].append(specs)

    def remove(self, key):
        """Remove a rule from the matcher. A KeyError is raised if the key does
        not exist.

        key (unicode): The ID of the match rule.
        """
        cdef int i = 0
        key = self._normalize_key(key)
        self._patterns.pop(key)
        self._callbacks.pop(key)
        while i < self.patterns.size():
            pattern_key = get_pattern_key(self.patterns.at(i))
            if pattern_key == key:
                # Erasing shifts the following entries down, so `i` stays put.
                self.patterns.erase(self.patterns.begin()+i)
            else:
                i += 1

    def has_key(self, key):
        """Check whether the matcher has a rule with a given key.

        key (string or int): The key to check.
        RETURNS (bool): Whether the matcher has the rule.
        """
        key = self._normalize_key(key)
        return key in self._patterns

    def get(self, key, default=None):
        """Retrieve the pattern stored for a key.

        key (unicode or int): The key to retrieve.
        default: Value returned when the key is unknown.
        RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
        """
        key = self._normalize_key(key)
        if key not in self._patterns:
            return default
        return (self._callbacks[key], self._patterns[key])

    def pipe(self, docs, batch_size=1000, n_threads=2):
        """Match a stream of documents, yielding them in turn.

        docs (iterable): A stream of documents.
        batch_size (int): The number of documents to accumulate into a working set.
        n_threads (int): The number of threads with which to work on the buffer
            in parallel, if the `Matcher` implementation supports multi-threading.
        YIELDS (Doc): Documents, in order.
        """
        # `batch_size` and `n_threads` are accepted but not used here.
        for doc in docs:
            self(doc)
            yield doc

    def __call__(self, Doc doc):
        """Find all token sequences matching the supplied patterns on the `Doc`.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples,
            describing the matches. A match tuple describes a span
            `doc[start:end]`. The `key` is the normalized rule key.
        """
        cdef vector[StateC] partials
        cdef int q = 0
        cdef int i, token_i
        cdef const TokenC* token
        cdef StateC state
        matches = []
        for token_i in range(doc.length):
            token = &doc.c[token_i]
            q = 0
            # Go over the open matches, extending or finalizing if able. Otherwise,
            # we over-write them (q doesn't advance)
            for state in partials:
                action = get_action(state.second, token)
                if action == PANIC:
                    raise Exception("Error selecting action in matcher")
                while action == ADVANCE_ZERO:
                    state.second += 1
                    action = get_action(state.second, token)
                if action == REPEAT:
                    # Keep the state open so it can greedily match more
                    # tokens. It has to be written to slot `q`: an earlier
                    # state may have been dropped, and `state` may have moved
                    # past skipped optional slots above.
                    partials[q] = state
                    q += 1
                elif action == REJECT:
                    pass
                elif action == ADVANCE:
                    partials[q] = state
                    partials[q].second += 1
                    q += 1
                elif action == ACCEPT:
                    # TODO: What to do about patterns starting with ZERO? Need to
                    # adjust the start position.
                    start = state.first
                    end = token_i+1
                    # The slot after the current one is the terminal entry,
                    # which carries the rule key.
                    ent_id = state.second[1].attrs[0].value
                    matches.append((ent_id, start, end))
            partials.resize(q)
            # Check whether we open any new patterns on this token
            for pattern in self.patterns:
                action = get_action(pattern, token)
                if action == PANIC:
                    raise Exception("Error selecting action in matcher")
                while action == ADVANCE_ZERO:
                    pattern += 1
                    action = get_action(pattern, token)
                if action == REPEAT:
                    state.first = token_i
                    state.second = pattern
                    partials.push_back(state)
                elif action == ADVANCE:
                    # TODO: What to do about patterns starting with ZERO? Need to
                    # adjust the start position.
                    state.first = token_i
                    state.second = pattern + 1
                    partials.push_back(state)
                elif action == ACCEPT:
                    start = token_i
                    end = token_i+1
                    ent_id = pattern[1].attrs[0].value
                    matches.append((ent_id, start, end))
        # Look for open patterns that are actually satisfied
        for state in partials:
            while state.second.quantifier in (ZERO, ZERO_PLUS):
                state.second += 1
                if state.second.nr_attr == 0:
                    start = state.first
                    end = len(doc)
                    ent_id = state.second.attrs[0].value
                    matches.append((ent_id, start, end))
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

    def _normalize_key(self, key):
        """Map a string key to its ID in the vocab's string store; any other
        key is returned unchanged."""
        if isinstance(key, basestring):
            return self.vocab.strings[key]
        else:
            return key
|
2016-02-03 04:04:55 +03:00
|
|
|
|
|
2015-10-08 18:00:45 +03:00
|
|
|
|
|
2016-10-17 16:23:31 +03:00
|
|
|
|
def get_bilou(length):
    """Return the lexeme flag IDs that tag each position of a phrase.

    length (int): Number of tokens in the phrase, from 1 to 10.
    RETURNS (list): One flag per token: the unit flag for a single token,
        otherwise the begin flag, `length - 2` inside flags and the last
        flag, all taken from the set reserved for that length.
    RAISES (ValueError): If `length` is not in the supported range.
    """
    if length == 1:
        return [U_ENT]
    # (length, begin flag, inside flag, last flag); two-token phrases have
    # no inside position.
    flag_sets = (
        (2, B2_ENT, None, L2_ENT),
        (3, B3_ENT, I3_ENT, L3_ENT),
        (4, B4_ENT, I4_ENT, L4_ENT),
        (5, B5_ENT, I5_ENT, L5_ENT),
        (6, B6_ENT, I6_ENT, L6_ENT),
        (7, B7_ENT, I7_ENT, L7_ENT),
        (8, B8_ENT, I8_ENT, L8_ENT),
        (9, B9_ENT, I9_ENT, L9_ENT),
        (10, B10_ENT, I10_ENT, L10_ENT),
    )
    for size, begin, inside, last in flag_sets:
        if length == size:
            return [begin] + [inside] * (size - 2) + [last]
    raise ValueError("Max length currently 10 for phrase matching")
|
|
|
|
|
|
|
|
|
|
|
2015-10-08 18:00:45 +03:00
|
|
|
|
cdef class PhraseMatcher:
    """Match and merge known multi-token phrases.

    Each token of an added phrase gets a position flag on its lexeme; a
    `Matcher` finds flag sequences, and a hash of the tokens' `orth` values
    confirms that the candidate is one of the added phrases.
    """
    cdef Pool mem
    cdef Vocab vocab
    cdef Matcher matcher
    cdef PreshMap phrase_ids

    cdef int max_length
    cdef attr_t* _phrase_key

    def __init__(self, Vocab vocab, phrases, max_length=10):
        """Create the PhraseMatcher.

        vocab (Vocab): The vocabulary shared with the documents to match.
        phrases (iterable): `Doc` objects to register; phrases with
            `max_length` or more tokens are skipped.
        max_length (int): Exclusive upper bound on phrase length.
        """
        self.mem = Pool()
        self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
        self.max_length = max_length
        self.vocab = vocab
        self.matcher = Matcher(self.vocab)
        self.phrase_ids = PreshMap()
        for phrase in phrases:
            if len(phrase) < max_length:
                self.add(phrase)

        # One flag-sequence pattern per supported phrase length.
        abstract_patterns = []
        for length in range(1, max_length):
            abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
        # `Matcher.add` takes the patterns as positional varargs. Candidates
        # are filtered in `__call__`, so no callback is registered.
        self.matcher.add('Candidate', None, *abstract_patterns)

    def add(self, Doc tokens):
        """Register a phrase.

        tokens (Doc): The phrase; must have fewer than `max_length` tokens.
        """
        cdef int length = tokens.length
        cdef int i
        cdef hash_t key
        assert length < self.max_length
        tags = get_bilou(length)
        assert len(tags) == length, length

        for i in range(self.max_length):
            self._phrase_key[i] = 0
        for i, tag in enumerate(tags):
            lexeme = self.vocab[tokens.c[i].lex.orth]
            lexeme.set_flag(tag, True)
            self._phrase_key[i] = lexeme.orth
        key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
        self.phrase_ids[key] = True

    def __call__(self, Doc doc):
        """Find registered phrases in `doc` and merge each into one token.

        doc (Doc): The document to match over; it is modified in place.
        RETURNS (list): `(start_char, end_char, tag, text, 'MWE')` tuples,
            one per merged phrase.
        """
        matches = []
        # The matcher yields 3-tuples and only checks position flags, so the
        # candidate still has to be confirmed against the phrase hashes.
        for ent_id, start, end in self.matcher(doc):
            if not self._is_known_phrase(doc, start, end):
                continue
            cand = doc[start : end]
            start = cand[0].idx
            end = cand[-1].idx + len(cand[-1])
            matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
        for match in matches:
            doc.merge(*match)
        return matches

    def pipe(self, stream, batch_size=1000, n_threads=2):
        """Match a stream of documents, yielding them in turn.

        stream (iterable): A stream of documents.
        YIELDS (Doc): Documents, in order, after matching.
        """
        # `batch_size` and `n_threads` are accepted but not used here.
        for doc in stream:
            self(doc)
            yield doc

    def accept_match(self, Doc doc, int ent_id, int label, int start, int end):
        """Check a candidate span against the registered phrases.

        RETURNS: `(ent_id, label, start, end)` if `doc[start:end]` is a
            registered phrase, otherwise False.
        """
        if self._is_known_phrase(doc, start, end):
            return (ent_id, label, start, end)
        else:
            return False

    def _is_known_phrase(self, Doc doc, int start, int end):
        """Return whether the tokens of `doc[start:end]` hash to a phrase
        that was registered with `add`."""
        cdef int i
        cdef hash_t key
        assert (end - start) < self.max_length
        for i in range(self.max_length):
            self._phrase_key[i] = 0
        for i in range(end - start):
            self._phrase_key[i] = doc.c[start + i].lex.orth
        key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
        return bool(self.phrase_ids.get(key))
|