Merge pull request #5498 from adrianeboyd/bugfix/phrasematcher-unpickle-new-api

This commit is contained in:
Ines Montani 2020-05-25 13:25:07 +02:00 committed by GitHub
commit ade4767e06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 1 deletions

View File

@ -332,7 +332,7 @@ def unpickle_matcher(vocab, docs, callbacks, attr):
matcher = PhraseMatcher(vocab, attr=attr) matcher = PhraseMatcher(vocab, attr=attr)
for key, specs in docs.items(): for key, specs in docs.items():
callback = callbacks.get(key, None) callback = callbacks.get(key, None)
matcher.add(key, callback, *specs) matcher.add(key, specs, on_match=callback)
return matcher return matcher

View File

@ -2,6 +2,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import pytest import pytest
import srsly
from mock import Mock from mock import Mock
from spacy.matcher import PhraseMatcher from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc from spacy.tokens import Doc
@ -266,3 +267,26 @@ def test_phrase_matcher_basic_check(en_vocab):
pattern = Doc(en_vocab, words=["hello", "world"]) pattern = Doc(en_vocab, words=["hello", "world"])
with pytest.raises(ValueError): with pytest.raises(ValueError):
matcher.add("TEST", pattern) matcher.add("TEST", pattern)
def test_phrase_matcher_pickle(en_vocab):
matcher = PhraseMatcher(en_vocab)
mock = Mock()
matcher.add("TEST", [Doc(en_vocab, words=["test"])])
matcher.add("TEST2", [Doc(en_vocab, words=["test2"])], on_match=mock)
doc = Doc(en_vocab, words=["these", "are", "tests", ":", "test", "test2"])
assert len(matcher) == 2
b = srsly.pickle_dumps(matcher)
matcher_unpickled = srsly.pickle_loads(b)
# call after pickling to avoid recursion error related to mock
matches = matcher(doc)
matches_unpickled = matcher_unpickled(doc)
assert len(matcher) == len(matcher_unpickled)
assert matches == matches_unpickled
# clunky way to vaguely check that callback is unpickled
(vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
assert isinstance(callbacks.get("TEST2"), Mock)