mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Set as_tuples on Doc during processing (#9592)
* Set as_tuples on Doc during processing * Fix types * Format
This commit is contained in:
parent
667572adca
commit
5a979137a7
|
@ -1091,6 +1091,12 @@ class Language:
|
|||
return self.make_doc(doc_like)
|
||||
raise ValueError(Errors.E866.format(type=type(doc_like)))
|
||||
|
||||
def _ensure_doc_with_context(self, doc_like: Union[str, Doc], context: Any) -> Doc:
|
||||
"""Create a Doc if need be and add as_tuples context, or raise an error if the input is not a Doc or a string."""
|
||||
doc = self._ensure_doc(doc_like)
|
||||
doc._context = context
|
||||
return doc
|
||||
|
||||
def update(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
|
@ -1474,7 +1480,7 @@ class Language:
|
|||
@overload
|
||||
def pipe( # noqa: F811
|
||||
self,
|
||||
texts: Iterable[Tuple[str, _AnyContext]],
|
||||
texts: Iterable[Tuple[Union[str, Doc], _AnyContext]],
|
||||
*,
|
||||
as_tuples: Literal[True] = ...,
|
||||
batch_size: Optional[int] = ...,
|
||||
|
@ -1486,7 +1492,9 @@ class Language:
|
|||
|
||||
def pipe( # noqa: F811
|
||||
self,
|
||||
texts: Union[Iterable[Union[str, Doc]], Iterable[Tuple[str, _AnyContext]]],
|
||||
texts: Union[
|
||||
Iterable[Union[str, Doc]], Iterable[Tuple[Union[str, Doc], _AnyContext]]
|
||||
],
|
||||
*,
|
||||
as_tuples: bool = False,
|
||||
batch_size: Optional[int] = None,
|
||||
|
@ -1512,18 +1520,20 @@ class Language:
|
|||
"""
|
||||
# Handle texts with context as tuples
|
||||
if as_tuples:
|
||||
texts = cast(Iterable[Tuple[str, _AnyContext]], texts)
|
||||
text_context1, text_context2 = itertools.tee(texts)
|
||||
texts = (tc[0] for tc in text_context1)
|
||||
contexts = (tc[1] for tc in text_context2)
|
||||
texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
|
||||
docs_with_contexts = (
|
||||
self._ensure_doc_with_context(text, context) for text, context in texts
|
||||
)
|
||||
docs = self.pipe(
|
||||
texts,
|
||||
docs_with_contexts,
|
||||
batch_size=batch_size,
|
||||
disable=disable,
|
||||
n_process=n_process,
|
||||
component_cfg=component_cfg,
|
||||
)
|
||||
for doc, context in zip(docs, contexts):
|
||||
for doc in docs:
|
||||
context = doc._context
|
||||
doc._context = None
|
||||
yield (doc, context)
|
||||
return
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ cdef class Doc:
|
|||
|
||||
cdef public bint has_unknown_spaces
|
||||
|
||||
cdef public list _py_tokens
|
||||
cdef public object _context
|
||||
|
||||
cdef int length
|
||||
cdef int max_length
|
||||
|
|
|
@ -29,6 +29,7 @@ class Doc:
|
|||
tensor: numpy.ndarray
|
||||
user_data: Dict[str, Any]
|
||||
has_unknown_spaces: bool
|
||||
_context: Any
|
||||
@classmethod
|
||||
def set_extension(
|
||||
cls,
|
||||
|
|
Loading…
Reference in New Issue
Block a user