Set as_tuples on Doc during processing (#9592)

* Set as_tuples on Doc during processing

* Fix types

* Format
This commit is contained in:
Adriane Boyd 2021-11-02 15:08:22 +01:00 committed by GitHub
parent 667572adca
commit 5a979137a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 9 deletions

View File

@ -1091,6 +1091,12 @@ class Language:
return self.make_doc(doc_like)
raise ValueError(Errors.E866.format(type=type(doc_like)))
def _ensure_doc_with_context(self, doc_like: Union[str, Doc], context: Any) -> Doc:
    """Return a Doc for the input and attach the as_tuples context to it.

    The input is handed to `_ensure_doc`, which creates a Doc if need be or
    raises an error if the input is not a Doc or a string. The given context
    is then stored on the resulting Doc's `_context` attribute.
    """
    # Conversion and input validation are delegated to _ensure_doc.
    ensured = self._ensure_doc(doc_like)
    # Carry the context on the Doc itself so it stays paired with it.
    ensured._context = context
    return ensured
def update(
self,
examples: Iterable[Example],
@ -1474,7 +1480,7 @@ class Language:
@overload
def pipe( # noqa: F811
self,
texts: Iterable[Tuple[str, _AnyContext]],
texts: Iterable[Tuple[Union[str, Doc], _AnyContext]],
*,
as_tuples: Literal[True] = ...,
batch_size: Optional[int] = ...,
@ -1486,7 +1492,9 @@ class Language:
def pipe( # noqa: F811
self,
texts: Union[Iterable[Union[str, Doc]], Iterable[Tuple[str, _AnyContext]]],
texts: Union[
Iterable[Union[str, Doc]], Iterable[Tuple[Union[str, Doc], _AnyContext]]
],
*,
as_tuples: bool = False,
batch_size: Optional[int] = None,
@ -1512,18 +1520,20 @@ class Language:
"""
# Handle texts with context as tuples
if as_tuples:
texts = cast(Iterable[Tuple[str, _AnyContext]], texts)
text_context1, text_context2 = itertools.tee(texts)
texts = (tc[0] for tc in text_context1)
contexts = (tc[1] for tc in text_context2)
texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
docs_with_contexts = (
self._ensure_doc_with_context(text, context) for text, context in texts
)
docs = self.pipe(
texts,
docs_with_contexts,
batch_size=batch_size,
disable=disable,
n_process=n_process,
component_cfg=component_cfg,
)
for doc, context in zip(docs, contexts):
for doc in docs:
context = doc._context
doc._context = None
yield (doc, context)
return

View File

@ -56,7 +56,7 @@ cdef class Doc:
cdef public bint has_unknown_spaces
cdef public list _py_tokens
cdef public object _context
cdef int length
cdef int max_length

View File

@ -29,6 +29,7 @@ class Doc:
tensor: numpy.ndarray
user_data: Dict[str, Any]
has_unknown_spaces: bool
_context: Any
@classmethod
def set_extension(
cls,