Switch to new add API in PhraseMatcher unpickle

This commit is contained in:
Adriane Boyd 2020-05-25 10:13:56 +02:00
parent ae1c179f3a
commit e06ca7ea24
2 changed files with 25 additions and 1 deletions

View File

@ -332,7 +332,7 @@ def unpickle_matcher(vocab, docs, callbacks, attr):
matcher = PhraseMatcher(vocab, attr=attr)
for key, specs in docs.items():
callback = callbacks.get(key, None)
matcher.add(key, callback, *specs)
matcher.add(key, specs, on_match=callback)
return matcher

View File

@ -2,6 +2,7 @@
from __future__ import unicode_literals
import pytest
import srsly
from mock import Mock
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc
@ -266,3 +267,26 @@ def test_phrase_matcher_basic_check(en_vocab):
pattern = Doc(en_vocab, words=["hello", "world"])
with pytest.raises(ValueError):
matcher.add("TEST", pattern)
def test_phrase_matcher_pickle(en_vocab):
matcher = PhraseMatcher(en_vocab)
mock = Mock()
matcher.add("TEST", [Doc(en_vocab, words=["test"])])
matcher.add("TEST2", [Doc(en_vocab, words=["test2"])], on_match=mock)
doc = Doc(en_vocab, words=["these", "are", "tests", ":", "test", "test2"])
assert len(matcher) == 2
b = srsly.pickle_dumps(matcher)
matcher_unpickled = srsly.pickle_loads(b)
# call after pickling to avoid recursion error related to mock
matches = matcher(doc)
matches_unpickled = matcher_unpickled(doc)
assert len(matcher) == len(matcher_unpickled)
assert matches == matches_unpickled
# clunky way to vaguely check that callback is unpickled
(vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
assert isinstance(callbacks.get("TEST2"), Mock)