Fix PhraseMatcher to remember attr on pickling (#4336)

* Fix PhraseMatcher to remember attr on pickling

* Check for attr as int or long
This commit is contained in:
adrianeboyd 2019-09-29 17:12:33 +02:00 committed by Ines Montani
parent 089f44cc56
commit ba5595c764

View File

@ -49,7 +49,7 @@ cdef class PhraseMatcher:
self._terminal_hash = 826361138722620965
map_init(self.mem, self.c_map, 8)
if isinstance(attr, long):
if isinstance(attr, (int, long)):
self.attr = attr
else:
attr = attr.upper()
@ -79,7 +79,7 @@ cdef class PhraseMatcher:
return key in self._callbacks
def __reduce__(self):
data = (self.vocab, self._docs, self._callbacks)
data = (self.vocab, self._docs, self._callbacks, self.attr)
return (unpickle_matcher, data, None, None)
def remove(self, key):
@ -171,15 +171,15 @@ cdef class PhraseMatcher:
for doc in docs:
if len(doc) == 0:
continue
if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
raise ValueError(Errors.E155.format())
if self.attr == DEP and not doc.is_parsed:
raise ValueError(Errors.E156.format())
if self._validate and (doc.is_tagged or doc.is_parsed) \
and self.attr not in (DEP, POS, TAG, LEMMA):
string_attr = self.vocab.strings[self.attr]
user_warning(Warnings.W012.format(key=key, attr=string_attr))
if isinstance(doc, Doc):
if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
raise ValueError(Errors.E155.format())
if self.attr == DEP and not doc.is_parsed:
raise ValueError(Errors.E156.format())
if self._validate and (doc.is_tagged or doc.is_parsed) \
and self.attr not in (DEP, POS, TAG, LEMMA):
string_attr = self.vocab.strings[self.attr]
user_warning(Warnings.W012.format(key=key, attr=string_attr))
keyword = self._convert_to_array(doc)
else:
keyword = doc
@ -310,8 +310,8 @@ cdef class PhraseMatcher:
return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))]
def unpickle_matcher(vocab, docs, callbacks):
matcher = PhraseMatcher(vocab)
def unpickle_matcher(vocab, docs, callbacks, attr):
matcher = PhraseMatcher(vocab, attr=attr)
for key, specs in docs.items():
callback = callbacks.get(key, None)
matcher.add(key, callback, *specs)