Support exclude in Doc.from_docs (#10689)

* Support exclude in Doc.from_docs

* Update API docs

* Add new tag to docs
This commit is contained in:
Adriane Boyd 2022-04-25 18:19:03 +02:00 committed by GitHub
parent 3b208197c3
commit 455f089c9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 37 deletions

View File

@ -1,6 +1,7 @@
import weakref import weakref
import numpy import numpy
from numpy.testing import assert_array_equal
import pytest import pytest
from thinc.api import NumpyOps, get_current_ops from thinc.api import NumpyOps, get_current_ops
@ -634,6 +635,14 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
assert "group" in m_doc.spans assert "group" in m_doc.spans
assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]]) assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])
# can exclude spans
m_doc = Doc.from_docs(en_docs, exclude=["spans"])
assert "group" not in m_doc.spans
# can exclude user_data
m_doc = Doc.from_docs(en_docs, exclude=["user_data"])
assert m_doc.user_data == {}
# can merge empty docs # can merge empty docs
doc = Doc.from_docs([en_tokenizer("")] * 10) doc = Doc.from_docs([en_tokenizer("")] * 10)
@ -647,6 +656,20 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
assert "group" in m_doc.spans assert "group" in m_doc.spans
assert len(m_doc.spans["group"]) == 0 assert len(m_doc.spans["group"]) == 0
# with tensor
ops = get_current_ops()
for doc in en_docs:
doc.tensor = ops.asarray([[len(t.text), 0.0] for t in doc])
m_doc = Doc.from_docs(en_docs)
assert_array_equal(
ops.to_numpy(m_doc.tensor),
ops.to_numpy(ops.xp.vstack([doc.tensor for doc in en_docs if len(doc)])),
)
# can exclude tensor
m_doc = Doc.from_docs(en_docs, exclude=["tensor"])
assert m_doc.tensor.shape == (0,)
def test_doc_api_from_docs_ents(en_tokenizer): def test_doc_api_from_docs_ents(en_tokenizer):
texts = ["Merging the docs is fun.", "They don't think alike."] texts = ["Merging the docs is fun.", "They don't think alike."]

View File

@ -11,7 +11,7 @@ from enum import Enum
import itertools import itertools
import numpy import numpy
import srsly import srsly
from thinc.api import get_array_module from thinc.api import get_array_module, get_current_ops
from thinc.util import copy_array from thinc.util import copy_array
import warnings import warnings
@ -1108,14 +1108,19 @@ cdef class Doc:
return self return self
@staticmethod @staticmethod
def from_docs(docs, ensure_whitespace=True, attrs=None): def from_docs(docs, ensure_whitespace=True, attrs=None, *, exclude=tuple()):
"""Concatenate multiple Doc objects to form a new one. Raises an error """Concatenate multiple Doc objects to form a new one. Raises an error
if the `Doc` objects do not all share the same `Vocab`. if the `Doc` objects do not all share the same `Vocab`.
docs (list): A list of Doc objects. docs (list): A list of Doc objects.
ensure_whitespace (bool): Insert a space between two adjacent docs whenever the first doc does not end in whitespace. ensure_whitespace (bool): Insert a space between two adjacent docs
attrs (list): Optional list of attribute ID ints or attribute name strings. whenever the first doc does not end in whitespace.
RETURNS (Doc): A doc that contains the concatenated docs, or None if no docs were given. attrs (list): Optional list of attribute ID ints or attribute name
strings.
exclude (Iterable[str]): Doc attributes to exclude. Supported
attributes: `spans`, `tensor`, `user_data`.
RETURNS (Doc): A doc that contains the concatenated docs, or None if no
docs were given.
DOCS: https://spacy.io/api/doc#from_docs DOCS: https://spacy.io/api/doc#from_docs
""" """
@ -1145,31 +1150,33 @@ cdef class Doc:
concat_words.extend(t.text for t in doc) concat_words.extend(t.text for t in doc)
concat_spaces.extend(bool(t.whitespace_) for t in doc) concat_spaces.extend(bool(t.whitespace_) for t in doc)
for key, value in doc.user_data.items(): if "user_data" not in exclude:
if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.": for key, value in doc.user_data.items():
data_type, name, start, end = key if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
if start is not None or end is not None: data_type, name, start, end = key
start += char_offset if start is not None or end is not None:
if end is not None: start += char_offset
end += char_offset if end is not None:
concat_user_data[(data_type, name, start, end)] = copy.copy(value) end += char_offset
concat_user_data[(data_type, name, start, end)] = copy.copy(value)
else:
warnings.warn(Warnings.W101.format(name=name))
else: else:
warnings.warn(Warnings.W101.format(name=name)) warnings.warn(Warnings.W102.format(key=key, value=value))
else: if "spans" not in exclude:
warnings.warn(Warnings.W102.format(key=key, value=value)) for key in doc.spans:
for key in doc.spans: # if a spans key is in any doc, include it in the merged doc
# if a spans key is in any doc, include it in the merged doc # even if it is empty
# even if it is empty if key not in concat_spans:
if key not in concat_spans: concat_spans[key] = []
concat_spans[key] = [] for span in doc.spans[key]:
for span in doc.spans[key]: concat_spans[key].append((
concat_spans[key].append(( span.start_char + char_offset,
span.start_char + char_offset, span.end_char + char_offset,
span.end_char + char_offset, span.label,
span.label, span.kb_id,
span.kb_id, span.text, # included as a check
span.text, # included as a check ))
))
char_offset += len(doc.text) char_offset += len(doc.text)
if len(doc) > 0 and ensure_whitespace and not doc[-1].is_space and not bool(doc[-1].whitespace_): if len(doc) > 0 and ensure_whitespace and not doc[-1].is_space and not bool(doc[-1].whitespace_):
char_offset += 1 char_offset += 1
@ -1210,6 +1217,10 @@ cdef class Doc:
else: else:
raise ValueError(Errors.E873.format(key=key, text=text)) raise ValueError(Errors.E873.format(key=key, text=text))
if "tensor" not in exclude and any(len(doc) for doc in docs):
ops = get_current_ops()
concat_doc.tensor = ops.xp.vstack([ops.asarray(doc.tensor) for doc in docs if len(doc)])
return concat_doc return concat_doc
def get_lca_matrix(self): def get_lca_matrix(self):

View File

@ -34,7 +34,7 @@ Construct a `Doc` object. The most common way to get a `Doc` object is via the
| Name | Description | | Name | Description |
| ---------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ---------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `vocab` | A storage container for lexical types. ~~Vocab~~ | | `vocab` | A storage container for lexical types. ~~Vocab~~ |
| `words` | A list of strings or integer hash values to add to the document as words. ~~Optional[List[Union[str,int]]]~~ | | `words` | A list of strings or integer hash values to add to the document as words. ~~Optional[List[Union[str,int]]]~~ |
| `spaces` | A list of boolean values indicating whether each word has a subsequent space. Must have the same length as `words`, if specified. Defaults to a sequence of `True`. ~~Optional[List[bool]]~~ | | `spaces` | A list of boolean values indicating whether each word has a subsequent space. Must have the same length as `words`, if specified. Defaults to a sequence of `True`. ~~Optional[List[bool]]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `user\_data` | Optional extra data to attach to the Doc. ~~Dict~~ | | `user\_data` | Optional extra data to attach to the Doc. ~~Dict~~ |
@ -304,7 +304,8 @@ ancestor is found, e.g. if span excludes a necessary ancestor.
## Doc.has_annotation {#has_annotation tag="method"} ## Doc.has_annotation {#has_annotation tag="method"}
Check whether the doc contains annotation on a [`Token` attribute](/api/token#attributes). Check whether the doc contains annotation on a
[`Token` attribute](/api/token#attributes).
<Infobox title="Changed in v3.0" variant="warning"> <Infobox title="Changed in v3.0" variant="warning">
@ -398,12 +399,14 @@ Concatenate multiple `Doc` objects to form a new one. Raises an error if the
> [str(ent) for doc in docs for ent in doc.ents] > [str(ent) for doc in docs for ent in doc.ents]
> ``` > ```
| Name | Description | | Name | Description |
| ------------------- | ----------------------------------------------------------------------------------------------------------------- | | -------------------------------------- | ----------------------------------------------------------------------------------------------------------------- |
| `docs` | A list of `Doc` objects. ~~List[Doc]~~ | | `docs` | A list of `Doc` objects. ~~List[Doc]~~ |
| `ensure_whitespace` | Insert a space between two adjacent docs whenever the first doc does not end in whitespace. ~~bool~~ | | `ensure_whitespace` | Insert a space between two adjacent docs whenever the first doc does not end in whitespace. ~~bool~~ |
| `attrs` | Optional list of attribute ID ints or attribute name strings. ~~Optional[List[Union[str, int]]]~~ | | `attrs` | Optional list of attribute ID ints or attribute name strings. ~~Optional[List[Union[str, int]]]~~ |
| **RETURNS** | The new `Doc` object that is containing the other docs or `None`, if `docs` is empty or `None`. ~~Optional[Doc]~~ | | _keyword-only_ | |
| `exclude` <Tag variant="new">3.3</Tag> | String names of Doc attributes to exclude. Supported: `spans`, `tensor`, `user_data`. ~~Iterable[str]~~ |
| **RETURNS** | The new `Doc` object that is containing the other docs or `None`, if `docs` is empty or `None`. ~~Optional[Doc]~~ |
## Doc.to_disk {#to_disk tag="method" new="2"} ## Doc.to_disk {#to_disk tag="method" new="2"}