Merge pull request #5006 from svlandeg/bugfix/multiproc-underscore

load Underscore state when multiprocessing
This commit is contained in:
Ines Montani 2020-02-25 14:46:02 +01:00 committed by GitHub
commit 4440a072d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 75 additions and 4 deletions

View File

@ -15,6 +15,7 @@ import multiprocessing as mp
from itertools import chain, cycle
from .tokenizer import Tokenizer
from .tokens.underscore import Underscore
from .vocab import Vocab
from .lemmatizer import Lemmatizer
from .lookups import Lookups
@ -853,7 +854,10 @@ class Language(object):
sender.send()
procs = [
mp.Process(target=_apply_pipes, args=(self.make_doc, pipes, rch, sch))
mp.Process(
target=_apply_pipes,
args=(self.make_doc, pipes, rch, sch, Underscore.get_state()),
)
for rch, sch in zip(texts_q, bytedocs_send_ch)
]
for proc in procs:
@ -1108,16 +1112,18 @@ def _pipe(docs, proc, kwargs):
yield doc
def _apply_pipes(make_doc, pipes, reciever, sender):
def _apply_pipes(make_doc, pipes, receiver, sender, underscore_state):
"""Worker for Language.pipe
receiver (multiprocessing.Connection): Pipe to receive text. Usually
created by `multiprocessing.Pipe()`
sender (multiprocessing.Connection): Pipe to send doc. Usually created by
`multiprocessing.Pipe()`
underscore_state (tuple): The data in the Underscore class of the parent
"""
Underscore.load_state(underscore_state)
while True:
texts = reciever.get()
texts = receiver.get()
docs = (make_doc(text) for text in texts)
for pipe in pipes:
docs = pipe(docs)

View File

@ -7,6 +7,15 @@ from spacy.tokens import Doc, Span, Token
from spacy.tokens.underscore import Underscore
@pytest.fixture(scope="function", autouse=True)
def clean_underscore():
# reset the Underscore object after the test, to avoid having state copied across tests
yield
Underscore.doc_extensions = {}
Underscore.span_extensions = {}
Underscore.token_extensions = {}
def test_create_doc_underscore():
doc = Mock()
doc.doc = doc

View File

@ -6,6 +6,7 @@ import re
from mock import Mock
from spacy.matcher import Matcher, DependencyMatcher
from spacy.tokens import Doc, Token
from ..doc.test_underscore import clean_underscore
@pytest.fixture
@ -200,6 +201,7 @@ def test_matcher_any_token_operator(en_vocab):
assert matches[2] == "test hello world"
@pytest.mark.usefixtures("clean_underscore")
def test_matcher_extension_attribute(en_vocab):
matcher = Matcher(en_vocab)
get_is_fruit = lambda token: token.text in ("apple", "banana")

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
from spacy.lang.en import English
from spacy.pipeline import EntityRuler
from spacy.tokens.underscore import Underscore
def test_issue4849():

View File

@ -0,0 +1,45 @@
# coding: utf8
from __future__ import unicode_literals
import spacy
from spacy.lang.en import English
from spacy.tokens import Span, Doc
from spacy.tokens.underscore import Underscore
class CustomPipe:
name = "my_pipe"
def __init__(self):
Span.set_extension("my_ext", getter=self._get_my_ext)
Doc.set_extension("my_ext", default=None)
def __call__(self, doc):
gathered_ext = []
for sent in doc.sents:
sent_ext = self._get_my_ext(sent)
sent._.set("my_ext", sent_ext)
gathered_ext.append(sent_ext)
doc._.set("my_ext", "\n".join(gathered_ext))
return doc
@staticmethod
def _get_my_ext(span):
return str(span.end)
def test_issue4903():
# ensures that this runs correctly and doesn't hang or crash on Windows / macOS
nlp = English()
custom_component = CustomPipe()
nlp.add_pipe(nlp.create_pipe("sentencizer"))
nlp.add_pipe(custom_component, after="sentencizer")
text = ["I like bananas.", "Do you like them?", "No, I prefer wasabi."]
docs = list(nlp.pipe(text, n_process=2))
assert docs[0].text == "I like bananas."
assert docs[1].text == "Do you like them?"
assert docs[2].text == "No, I prefer wasabi."

View File

@ -11,6 +11,6 @@ def nlp():
return spacy.blank("en")
def test_evaluate(nlp):
def test_issue4924(nlp):
docs_golds = [("", {})]
nlp.evaluate(docs_golds)

View File

@ -79,6 +79,14 @@ class Underscore(object):
def _get_key(self, name):
return ("._.", name, self._start, self._end)
@classmethod
def get_state(cls):
return cls.token_extensions, cls.span_extensions, cls.doc_extensions
@classmethod
def load_state(cls, state):
cls.token_extensions, cls.span_extensions, cls.doc_extensions = state
def get_ext_args(**kwargs):
"""Validate and convert arguments. Reused in Doc, Token and Span."""