Fix PhraseMatcher pickling and length (resolves #3248) (#3252)

This commit is contained in:
Ines Montani 2019-02-12 18:27:54 +01:00 committed by GitHub
parent 483dddc9bc
commit b589b945db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 2 deletions

View File

@ -33,6 +33,7 @@ cdef class PhraseMatcher:
cdef attr_id_t attr cdef attr_id_t attr
cdef public object _callbacks cdef public object _callbacks
cdef public object _patterns cdef public object _patterns
cdef public object _docs
cdef public object _validate cdef public object _validate
def __init__(self, Vocab vocab, max_length=0, attr='ORTH', validate=False): def __init__(self, Vocab vocab, max_length=0, attr='ORTH', validate=False):
@ -55,6 +56,7 @@ cdef class PhraseMatcher:
] ]
self.matcher.add('Candidate', None, *abstract_patterns) self.matcher.add('Candidate', None, *abstract_patterns)
self._callbacks = {} self._callbacks = {}
self._docs = {}
self._validate = validate self._validate = validate
def __len__(self): def __len__(self):
@ -64,7 +66,7 @@ cdef class PhraseMatcher:
RETURNS (int): The number of rules. RETURNS (int): The number of rules.
""" """
return len(self.phrase_ids) return len(self._docs)
def __contains__(self, key): def __contains__(self, key):
"""Check whether the matcher contains rules for a match ID. """Check whether the matcher contains rules for a match ID.
@ -76,7 +78,8 @@ cdef class PhraseMatcher:
return ent_id in self._callbacks return ent_id in self._callbacks
def __reduce__(self): def __reduce__(self):
return (self.__class__, (self.vocab,), None, None) data = (self.vocab, self._docs, self._callbacks)
return (unpickle_matcher, data, None, None)
def add(self, key, on_match, *docs): def add(self, key, on_match, *docs):
"""Add a match-rule to the phrase-matcher. A match-rule consists of: an ID """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
@ -89,6 +92,7 @@ cdef class PhraseMatcher:
cdef Doc doc cdef Doc doc
cdef hash_t ent_id = self.matcher._normalize_key(key) cdef hash_t ent_id = self.matcher._normalize_key(key)
self._callbacks[ent_id] = on_match self._callbacks[ent_id] = on_match
self._docs[ent_id] = docs
cdef int length cdef int length
cdef int i cdef int i
cdef hash_t phrase_hash cdef hash_t phrase_hash
@ -213,3 +217,11 @@ def get_bilou(length):
return [B3_ENT, I3_ENT, L3_ENT] return [B3_ENT, I3_ENT, L3_ENT]
else: else:
return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT] return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]
def unpickle_matcher(vocab, docs, callbacks):
    """Reconstruct a PhraseMatcher from the state emitted by
    `PhraseMatcher.__reduce__`.

    vocab (Vocab): The vocabulary handed to the new matcher.
    docs (dict): Pattern docs keyed by match ID; each value is unpacked as
        the `*docs` arguments of `PhraseMatcher.add`.
    callbacks (dict): On-match callbacks keyed by match ID. IDs without an
        entry are re-added with `None` as their callback.
    RETURNS (PhraseMatcher): A new matcher with every rule re-added.
    """
    # NOTE(review): only `vocab` is forwarded to the constructor, so
    # `max_length`, `attr` and `validate` fall back to their defaults on the
    # restored matcher -- confirm this is acceptable for callers.
    restored = PhraseMatcher(vocab)
    for match_id, patterns in docs.items():
        on_match = callbacks.get(match_id)
        restored.add(match_id, on_match, *patterns)
    return restored

View File

@ -0,0 +1,28 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
from spacy.matcher import PhraseMatcher
from spacy.lang.en import English
from spacy.compat import pickle
def test_issue3248_1():
    """Check that len() of a PhraseMatcher counts rules (match IDs) rather
    than the individual patterns registered under them."""
    nlp = English()
    matcher = PhraseMatcher(nlp.vocab)
    # Two rules holding four patterns in total: the length must be 2, not 4.
    rules = (("TEST1", ("a", "b", "c")), ("TEST2", ("d",)))
    for rule_id, texts in rules:
        patterns = [nlp(text) for text in texts]
        matcher.add(rule_id, None, *patterns)
    assert len(matcher) == 2
def test_issue3248_2():
    """Check that a PhraseMatcher survives a pickle round trip with the same
    number of rules."""
    nlp = English()
    matcher = PhraseMatcher(nlp.vocab)
    rules = (("TEST1", ("a", "b", "c")), ("TEST2", ("d",)))
    for rule_id, texts in rules:
        patterns = [nlp(text) for text in texts]
        matcher.add(rule_id, None, *patterns)
    # Serialize and immediately restore; the copy must report the same length.
    new_matcher = pickle.loads(pickle.dumps(matcher))
    assert len(new_matcher) == len(matcher)