mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-04 21:50:35 +03:00
Merge pull request #5498 from adrianeboyd/bugfix/phrasematcher-unpickle-new-api
This commit is contained in:
commit
ade4767e06
|
@ -332,7 +332,7 @@ def unpickle_matcher(vocab, docs, callbacks, attr):
|
||||||
matcher = PhraseMatcher(vocab, attr=attr)
|
matcher = PhraseMatcher(vocab, attr=attr)
|
||||||
for key, specs in docs.items():
|
for key, specs in docs.items():
|
||||||
callback = callbacks.get(key, None)
|
callback = callbacks.get(key, None)
|
||||||
matcher.add(key, callback, *specs)
|
matcher.add(key, specs, on_match=callback)
|
||||||
return matcher
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import srsly
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
from spacy.matcher import PhraseMatcher
|
from spacy.matcher import PhraseMatcher
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
@ -266,3 +267,26 @@ def test_phrase_matcher_basic_check(en_vocab):
|
||||||
pattern = Doc(en_vocab, words=["hello", "world"])
|
pattern = Doc(en_vocab, words=["hello", "world"])
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
matcher.add("TEST", pattern)
|
matcher.add("TEST", pattern)
|
||||||
|
|
||||||
|
|
||||||
|
def test_phrase_matcher_pickle(en_vocab):
|
||||||
|
matcher = PhraseMatcher(en_vocab)
|
||||||
|
mock = Mock()
|
||||||
|
matcher.add("TEST", [Doc(en_vocab, words=["test"])])
|
||||||
|
matcher.add("TEST2", [Doc(en_vocab, words=["test2"])], on_match=mock)
|
||||||
|
doc = Doc(en_vocab, words=["these", "are", "tests", ":", "test", "test2"])
|
||||||
|
assert len(matcher) == 2
|
||||||
|
|
||||||
|
b = srsly.pickle_dumps(matcher)
|
||||||
|
matcher_unpickled = srsly.pickle_loads(b)
|
||||||
|
|
||||||
|
# call after pickling to avoid recursion error related to mock
|
||||||
|
matches = matcher(doc)
|
||||||
|
matches_unpickled = matcher_unpickled(doc)
|
||||||
|
|
||||||
|
assert len(matcher) == len(matcher_unpickled)
|
||||||
|
assert matches == matches_unpickled
|
||||||
|
|
||||||
|
# clunky way to vaguely check that callback is unpickled
|
||||||
|
(vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
|
||||||
|
assert isinstance(callbacks.get("TEST2"), Mock)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user