Set as_tuples on Doc during processing (#9592)

* Set as_tuples on Doc during processing

* Fix types

* Format
This commit is contained in:
Adriane Boyd 2021-11-02 15:08:22 +01:00 committed by GitHub
parent 667572adca
commit 5a979137a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 9 deletions

View File

@ -1091,6 +1091,12 @@ class Language:
return self.make_doc(doc_like) return self.make_doc(doc_like)
raise ValueError(Errors.E866.format(type=type(doc_like))) raise ValueError(Errors.E866.format(type=type(doc_like)))
def _ensure_doc_with_context(self, doc_like: Union[str, Doc], context: Any) -> Doc:
"""Create a Doc if need be and add as_tuples context, or raise an error if the input is not a Doc or a string."""
doc = self._ensure_doc(doc_like)
doc._context = context
return doc
def update( def update(
self, self,
examples: Iterable[Example], examples: Iterable[Example],
@ -1474,7 +1480,7 @@ class Language:
@overload @overload
def pipe( # noqa: F811 def pipe( # noqa: F811
self, self,
texts: Iterable[Tuple[str, _AnyContext]], texts: Iterable[Tuple[Union[str, Doc], _AnyContext]],
*, *,
as_tuples: Literal[True] = ..., as_tuples: Literal[True] = ...,
batch_size: Optional[int] = ..., batch_size: Optional[int] = ...,
@ -1486,7 +1492,9 @@ class Language:
def pipe( # noqa: F811 def pipe( # noqa: F811
self, self,
texts: Union[Iterable[Union[str, Doc]], Iterable[Tuple[str, _AnyContext]]], texts: Union[
Iterable[Union[str, Doc]], Iterable[Tuple[Union[str, Doc], _AnyContext]]
],
*, *,
as_tuples: bool = False, as_tuples: bool = False,
batch_size: Optional[int] = None, batch_size: Optional[int] = None,
@ -1512,18 +1520,20 @@ class Language:
""" """
# Handle texts with context as tuples # Handle texts with context as tuples
if as_tuples: if as_tuples:
texts = cast(Iterable[Tuple[str, _AnyContext]], texts) texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
text_context1, text_context2 = itertools.tee(texts) docs_with_contexts = (
texts = (tc[0] for tc in text_context1) self._ensure_doc_with_context(text, context) for text, context in texts
contexts = (tc[1] for tc in text_context2) )
docs = self.pipe( docs = self.pipe(
texts, docs_with_contexts,
batch_size=batch_size, batch_size=batch_size,
disable=disable, disable=disable,
n_process=n_process, n_process=n_process,
component_cfg=component_cfg, component_cfg=component_cfg,
) )
for doc, context in zip(docs, contexts): for doc in docs:
context = doc._context
doc._context = None
yield (doc, context) yield (doc, context)
return return

View File

@ -56,7 +56,7 @@ cdef class Doc:
cdef public bint has_unknown_spaces cdef public bint has_unknown_spaces
cdef public list _py_tokens cdef public object _context
cdef int length cdef int length
cdef int max_length cdef int max_length

View File

@ -29,6 +29,7 @@ class Doc:
tensor: numpy.ndarray tensor: numpy.ndarray
user_data: Dict[str, Any] user_data: Dict[str, Any]
has_unknown_spaces: bool has_unknown_spaces: bool
_context: Any
@classmethod @classmethod
def set_extension( def set_extension(
cls, cls,