load Underscore state when multiprocessing

This commit is contained in:
svlandeg 2020-02-12 11:50:42 +01:00
parent 99a543367d
commit ecbb9c4b9f
2 changed files with 16 additions and 3 deletions

View File

@ -15,6 +15,7 @@ import multiprocessing as mp
from itertools import chain, cycle from itertools import chain, cycle
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .tokens.underscore import Underscore
from .vocab import Vocab from .vocab import Vocab
from .lemmatizer import Lemmatizer from .lemmatizer import Lemmatizer
from .lookups import Lookups from .lookups import Lookups
@ -852,7 +853,10 @@ class Language(object):
sender.send() sender.send()
procs = [ procs = [
mp.Process(target=_apply_pipes, args=(self.make_doc, pipes, rch, sch)) mp.Process(
target=_apply_pipes,
args=(self.make_doc, pipes, rch, sch, Underscore.get_state()),
)
for rch, sch in zip(texts_q, bytedocs_send_ch) for rch, sch in zip(texts_q, bytedocs_send_ch)
] ]
for proc in procs: for proc in procs:
@ -1107,7 +1111,7 @@ def _pipe(docs, proc, kwargs):
yield doc yield doc
def _apply_pipes(make_doc, pipes, reciever, sender): def _apply_pipes(make_doc, pipes, receiver, sender, underscore_state):
"""Worker for Language.pipe """Worker for Language.pipe
receiver (multiprocessing.Connection): Pipe to receive text. Usually receiver (multiprocessing.Connection): Pipe to receive text. Usually
@ -1115,8 +1119,9 @@ def _apply_pipes(make_doc, pipes, reciever, sender):
sender (multiprocessing.Connection): Pipe to send doc. Usually created by sender (multiprocessing.Connection): Pipe to send doc. Usually created by
`multiprocessing.Pipe()` `multiprocessing.Pipe()`
""" """
Underscore.load_state(underscore_state)
while True: while True:
texts = reciever.get() texts = receiver.get()
docs = (make_doc(text) for text in texts) docs = (make_doc(text) for text in texts)
for pipe in pipes: for pipe in pipes:
docs = pipe(docs) docs = pipe(docs)

View File

@ -79,6 +79,14 @@ class Underscore(object):
def _get_key(self, name): def _get_key(self, name):
return ("._.", name, self._start, self._end) return ("._.", name, self._start, self._end)
@classmethod
def get_state(cls):
    """Return the class-level extension registries as one tuple.

    RETURNS (tuple): `(token_extensions, span_extensions, doc_extensions)`,
        in that order. The items are the very objects stored on the class,
        not copies.
    """
    state = (cls.token_extensions, cls.span_extensions, cls.doc_extensions)
    return state
@classmethod
def load_state(cls, state):
    """Install extension registries on the class from a saved state.

    state (iterable): Exactly three items, in the order
        `(token_extensions, span_extensions, doc_extensions)`, i.e. the
        layout produced by `get_state`.
    """
    # Unpack first so a wrongly sized state raises before anything on the
    # class has been overwritten.
    token_ext, span_ext, doc_ext = state
    cls.token_extensions = token_ext
    cls.span_extensions = span_ext
    cls.doc_extensions = doc_ext
def get_ext_args(**kwargs): def get_ext_args(**kwargs):
"""Validate and convert arguments. Reused in Doc, Token and Span.""" """Validate and convert arguments. Reused in Doc, Token and Span."""