mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
* Fix phrase matcher
This commit is contained in:
parent
b3a70e6375
commit
801d55a6d9
|
@ -1,11 +1,18 @@
|
|||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from os import path
|
||||
|
||||
from .typedefs cimport attr_t
|
||||
from .typedefs cimport hash_t
|
||||
from .attrs cimport attr_id_t
|
||||
from .structs cimport TokenC
|
||||
from .structs cimport TokenC, LexemeC
|
||||
from .lexeme cimport Lexeme
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
from preshed.maps cimport PreshMap
|
||||
from libcpp.vector cimport vector
|
||||
from murmurhash.mrmr cimport hash64
|
||||
|
||||
from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
|
||||
from .attrs cimport FLAG13, FLAG14, FLAG15, FLAG16, FLAG17, FLAG18, FLAG19, FLAG20, FLAG21, FLAG22, FLAG23, FLAG24, FLAG25
|
||||
|
@ -15,6 +22,38 @@ from .vocab cimport Vocab
|
|||
|
||||
from libcpp.vector cimport vector
|
||||
|
||||
from .attrs import FLAG61 as U_ENT
|
||||
|
||||
from .attrs import FLAG60 as B2_ENT
|
||||
from .attrs import FLAG59 as B3_ENT
|
||||
from .attrs import FLAG58 as B4_ENT
|
||||
from .attrs import FLAG57 as B5_ENT
|
||||
from .attrs import FLAG56 as B6_ENT
|
||||
from .attrs import FLAG55 as B7_ENT
|
||||
from .attrs import FLAG54 as B8_ENT
|
||||
from .attrs import FLAG53 as B9_ENT
|
||||
from .attrs import FLAG52 as B10_ENT
|
||||
|
||||
from .attrs import FLAG51 as I3_ENT
|
||||
from .attrs import FLAG50 as I4_ENT
|
||||
from .attrs import FLAG49 as I5_ENT
|
||||
from .attrs import FLAG48 as I6_ENT
|
||||
from .attrs import FLAG47 as I7_ENT
|
||||
from .attrs import FLAG46 as I8_ENT
|
||||
from .attrs import FLAG45 as I9_ENT
|
||||
from .attrs import FLAG44 as I10_ENT
|
||||
|
||||
from .attrs import FLAG43 as L2_ENT
|
||||
from .attrs import FLAG42 as L3_ENT
|
||||
from .attrs import FLAG41 as L4_ENT
|
||||
from .attrs import FLAG40 as L5_ENT
|
||||
from .attrs import FLAG39 as L6_ENT
|
||||
from .attrs import FLAG38 as L7_ENT
|
||||
from .attrs import FLAG37 as L8_ENT
|
||||
from .attrs import FLAG36 as L9_ENT
|
||||
from .attrs import FLAG35 as L10_ENT
|
||||
|
||||
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
|
@ -41,7 +80,7 @@ cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) exc
|
|||
pattern[i].spec[j].attr = attr
|
||||
pattern[i].spec[j].value = value
|
||||
i = len(token_specs)
|
||||
pattern[i].spec = <AttrValue*>mem.alloc(1, sizeof(AttrValue))
|
||||
pattern[i].spec = <AttrValue*>mem.alloc(2, sizeof(AttrValue))
|
||||
pattern[i].spec[0].attr = ENT_TYPE
|
||||
pattern[i].spec[0].value = entity_type
|
||||
pattern[i].spec[1].attr = LENGTH
|
||||
|
@ -81,7 +120,33 @@ def _convert_strings(token_specs, string_store):
|
|||
value = int(value)
|
||||
converted[-1].append((attr, value))
|
||||
return converted
|
||||
|
||||
|
||||
|
||||
def get_bilou(length):
|
||||
if length == 1:
|
||||
return [U_ENT]
|
||||
elif length == 2:
|
||||
return [B2_ENT, L2_ENT]
|
||||
elif length == 3:
|
||||
return [B3_ENT, I3_ENT, L3_ENT]
|
||||
elif length == 4:
|
||||
return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
|
||||
elif length == 5:
|
||||
return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
|
||||
elif length == 6:
|
||||
return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
|
||||
elif length == 7:
|
||||
return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
|
||||
elif length == 8:
|
||||
return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
|
||||
elif length == 9:
|
||||
return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
|
||||
elif length == 10:
|
||||
return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
|
||||
I10_ENT, I10_ENT, L10_ENT]
|
||||
else:
|
||||
raise ValueError("Max length currently 10 for phrase matching")
|
||||
|
||||
|
||||
def map_attr_name(attr):
|
||||
attr = attr.upper()
|
||||
|
@ -95,32 +160,6 @@ def map_attr_name(attr):
|
|||
return SHAPE
|
||||
elif attr == 'NORM':
|
||||
return NORM
|
||||
elif attr == 'FLAG13':
|
||||
return FLAG13
|
||||
elif attr == 'FLAG14':
|
||||
return FLAG14
|
||||
elif attr == 'FLAG15':
|
||||
return FLAG15
|
||||
elif attr == 'FLAG16':
|
||||
return FLAG16
|
||||
elif attr == 'FLAG17':
|
||||
return FLAG17
|
||||
elif attr == 'FLAG18':
|
||||
return FLAG18
|
||||
elif attr == 'FLAG19':
|
||||
return FLAG19
|
||||
elif attr == 'FLAG20':
|
||||
return FLAG20
|
||||
elif attr == 'FLAG21':
|
||||
return FLAG21
|
||||
elif attr == 'FLAG22':
|
||||
return FLAG22
|
||||
elif attr == 'FLAG23':
|
||||
return FLAG23
|
||||
elif attr == 'FLAG24':
|
||||
return FLAG24
|
||||
elif attr == 'FLAG25':
|
||||
return FLAG25
|
||||
else:
|
||||
raise Exception("TODO: Finish supporting attr mapping %s" % attr)
|
||||
|
||||
|
@ -163,7 +202,7 @@ cdef class Matcher:
|
|||
spec = _convert_strings(spec, self.vocab.strings)
|
||||
self.patterns.push_back(init_pattern(self.mem, spec, etype))
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
def __call__(self, Doc doc, acceptor=None):
|
||||
cdef vector[Pattern*] partials
|
||||
cdef int n_partials = 0
|
||||
cdef int q = 0
|
||||
|
@ -174,21 +213,94 @@ cdef class Matcher:
|
|||
for token_i in range(doc.length):
|
||||
token = &doc.data[token_i]
|
||||
q = 0
|
||||
# Go over the open matches, extending or finalizing if able. Otherwise,
|
||||
# we over-write them (q doesn't advance)
|
||||
for i in range(partials.size()):
|
||||
state = partials.at(i)
|
||||
if match(state, token):
|
||||
if is_final(state):
|
||||
matches.append(get_entity(state, token, token_i))
|
||||
label, start, end = get_entity(state, token, token_i)
|
||||
if acceptor is None or acceptor(doc, label, start, end):
|
||||
matches.append((label, start, end))
|
||||
else:
|
||||
partials[q] = state + 1
|
||||
q += 1
|
||||
partials.resize(q)
|
||||
# Check whether we open any new patterns on this token
|
||||
for i in range(self.n_patterns):
|
||||
state = self.patterns[i]
|
||||
if match(state, token):
|
||||
if is_final(state):
|
||||
matches.append(get_entity(state, token, token_i))
|
||||
label, start, end = get_entity(state, token, token_i)
|
||||
if acceptor is None or acceptor(doc, label, start, end):
|
||||
matches.append((label, start, end))
|
||||
else:
|
||||
partials.push_back(state + 1)
|
||||
doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + matches
|
||||
return matches
|
||||
|
||||
|
||||
cdef class PhraseMatcher:
|
||||
cdef Pool mem
|
||||
cdef Vocab vocab
|
||||
cdef Matcher matcher
|
||||
cdef PreshMap phrase_ids
|
||||
|
||||
cdef int max_length
|
||||
cdef attr_t* _phrase_key
|
||||
|
||||
def __init__(self, Vocab vocab, phrases, max_length=10):
|
||||
self.mem = Pool()
|
||||
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
|
||||
self.max_length = max_length
|
||||
self.vocab = vocab
|
||||
self.matcher = Matcher(self.vocab, {})
|
||||
self.phrase_ids = PreshMap()
|
||||
for phrase in phrases:
|
||||
if len(phrase) < max_length:
|
||||
self.add(phrase)
|
||||
|
||||
abstract_patterns = []
|
||||
for length in range(1, max_length):
|
||||
abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
|
||||
self.matcher.add('Candidate', 'MWE', {}, abstract_patterns)
|
||||
|
||||
def add(self, Doc tokens):
|
||||
cdef int length = tokens.length
|
||||
assert length < self.max_length
|
||||
tags = get_bilou(length)
|
||||
assert len(tags) == length, length
|
||||
|
||||
cdef int i
|
||||
for i in range(self.max_length):
|
||||
self._phrase_key[i] = 0
|
||||
for i, tag in enumerate(tags):
|
||||
lexeme = self.vocab[tokens.data[i].lex.orth]
|
||||
lexeme.set_flag(tag, True)
|
||||
self._phrase_key[i] = lexeme.orth
|
||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||
self.phrase_ids[key] = True
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
matches = []
|
||||
for label, start, end in self.matcher(doc, acceptor=self.accept_match):
|
||||
cand = doc[start : end]
|
||||
start = cand[0].idx
|
||||
end = cand[-1].idx + len(cand[-1])
|
||||
matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
|
||||
for match in matches:
|
||||
doc.merge(*match)
|
||||
return matches
|
||||
|
||||
def accept_match(self, Doc doc, int label, int start, int end):
|
||||
assert (end - start) < self.max_length
|
||||
cdef int i, j
|
||||
for i in range(self.max_length):
|
||||
self._phrase_key[i] = 0
|
||||
for i, j in enumerate(range(start, end)):
|
||||
self._phrase_key[i] = doc.data[j].lex.orth
|
||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||
if self.phrase_ids.get(key):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
|
Loading…
Reference in New Issue
Block a user