mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Merging multiple docs into one (#5032)
* Add static method to Doc to allow merging of multiple docs. * Add error description for the error that occurs if docs with different vocabs (from different languages) are merged in Doc.from_docs(). * Add test for Doc.from_docs() implementation. * Fix using numpy's concatenate in Doc.from_docs. * Replace typing's type annotations in from_docs. * Simply remove type annotations in from_docs. * Add documentation for Doc.from_docs to api. * Simplify from_docs, its test and the api doc for codebase consistency. * Fix merging of Doc objects that end with whitespaces (Achieved by simply not setting the SPACY attribute on whitespace tokens). Remove two unnecessary imports of attributes. * Add merging of user data from Doc objects in from_docs. Add user data test case to corresponding test. Add applicable warning messages. * Fix incorrect setting of tokens idx by using concatenated spaces (again). Add test case to corresponding test. * Add MORPH to attrs * Update warnings calls * Remove out-dated error from merge * Rename space_delimiter to ensure_whitespace Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
41b65fd0f8
commit
e4dcac4a4b
|
@ -159,6 +159,8 @@ class Warnings(object):
|
||||||
W100 = ("Skipping unsupported morphological feature(s): '{feature}'. "
|
W100 = ("Skipping unsupported morphological feature(s): '{feature}'. "
|
||||||
"Provide features as a dict {{\"Field1\": \"Value1,Value2\"}} or "
|
"Provide features as a dict {{\"Field1\": \"Value1,Value2\"}} or "
|
||||||
"string \"Field1=Value1,Value2|Field2=Value3\".")
|
"string \"Field1=Value1,Value2|Field2=Value3\".")
|
||||||
|
W101 = ("Skipping `Doc` custom extension '{name}' while merging docs.")
|
||||||
|
W102 = ("Skipping unsupported user data '{key}: {value}' while merging docs.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
@ -593,6 +595,8 @@ class Errors(object):
|
||||||
E997 = ("Tokenizer special cases are not allowed to modify the text. "
|
E997 = ("Tokenizer special cases are not allowed to modify the text. "
|
||||||
"This would map '{chunk}' to '{orth}' given token attributes "
|
"This would map '{chunk}' to '{orth}' given token attributes "
|
||||||
"'{token_attrs}'.")
|
"'{token_attrs}'.")
|
||||||
|
E999 = ("Unable to merge the `Doc` objects because they do not all share "
|
||||||
|
"the same `Vocab`.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -303,6 +303,60 @@ def test_doc_from_array_sent_starts(en_vocab):
|
||||||
assert new_doc.is_parsed
|
assert new_doc.is_parsed
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
|
||||||
|
en_texts = ["Merging the docs is fun.", "They don't think alike."]
|
||||||
|
de_text = "Wie war die Frage?"
|
||||||
|
en_docs = [en_tokenizer(text) for text in en_texts]
|
||||||
|
docs_idx = en_texts[0].index('docs')
|
||||||
|
de_doc = de_tokenizer(de_text)
|
||||||
|
en_docs[0].user_data[("._.", "is_ambiguous", docs_idx, None)] = (True, None, None, None)
|
||||||
|
|
||||||
|
assert Doc.from_docs([]) is None
|
||||||
|
|
||||||
|
assert de_doc is not Doc.from_docs([de_doc])
|
||||||
|
assert str(de_doc) == str(Doc.from_docs([de_doc]))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Doc.from_docs(en_docs + [de_doc])
|
||||||
|
|
||||||
|
m_doc = Doc.from_docs(en_docs)
|
||||||
|
assert len(en_docs) == len(list(m_doc.sents))
|
||||||
|
assert len(str(m_doc)) > len(en_texts[0]) + len(en_texts[1])
|
||||||
|
assert str(m_doc) == " ".join(en_texts)
|
||||||
|
p_token = m_doc[len(en_docs[0])-1]
|
||||||
|
assert p_token.text == "." and bool(p_token.whitespace_)
|
||||||
|
en_docs_tokens = [t for doc in en_docs for t in doc]
|
||||||
|
assert len(m_doc) == len(en_docs_tokens)
|
||||||
|
think_idx = len(en_texts[0]) + 1 + en_texts[1].index('think')
|
||||||
|
assert m_doc[9].idx == think_idx
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
not_available = m_doc[2]._.is_ambiguous # not callable, because it was not set via set_extension
|
||||||
|
assert len(m_doc.user_data) == len(en_docs[0].user_data) # but it's there
|
||||||
|
|
||||||
|
m_doc = Doc.from_docs(en_docs, ensure_whitespace=False)
|
||||||
|
assert len(en_docs) == len(list(m_doc.sents))
|
||||||
|
assert len(str(m_doc)) == len(en_texts[0]) + len(en_texts[1])
|
||||||
|
assert str(m_doc) == "".join(en_texts)
|
||||||
|
p_token = m_doc[len(en_docs[0]) - 1]
|
||||||
|
assert p_token.text == "." and not bool(p_token.whitespace_)
|
||||||
|
en_docs_tokens = [t for doc in en_docs for t in doc]
|
||||||
|
assert len(m_doc) == len(en_docs_tokens)
|
||||||
|
think_idx = len(en_texts[0]) + 0 + en_texts[1].index('think')
|
||||||
|
assert m_doc[9].idx == think_idx
|
||||||
|
|
||||||
|
m_doc = Doc.from_docs(en_docs, attrs=['lemma', 'length', 'pos'])
|
||||||
|
with pytest.raises(ValueError): # important attributes from sentenziser or parser are missing
|
||||||
|
assert list(m_doc.sents)
|
||||||
|
assert len(str(m_doc)) > len(en_texts[0]) + len(en_texts[1])
|
||||||
|
assert str(m_doc) == " ".join(en_texts) # space delimiter considered, although spacy attribute was missing
|
||||||
|
p_token = m_doc[len(en_docs[0]) - 1]
|
||||||
|
assert p_token.text == "." and bool(p_token.whitespace_)
|
||||||
|
en_docs_tokens = [t for doc in en_docs for t in doc]
|
||||||
|
assert len(m_doc) == len(en_docs_tokens)
|
||||||
|
think_idx = len(en_texts[0]) + 1 + en_texts[1].index('think')
|
||||||
|
assert m_doc[9].idx == think_idx
|
||||||
|
|
||||||
|
|
||||||
def test_doc_lang(en_vocab):
|
def test_doc_lang(en_vocab):
|
||||||
doc = Doc(en_vocab, words=["Hello", "world"])
|
doc = Doc(en_vocab, words=["Hello", "world"])
|
||||||
assert doc.lang_ == "en"
|
assert doc.lang_ == "en"
|
||||||
|
|
|
@ -5,6 +5,7 @@ from libc.string cimport memcpy, memset
|
||||||
from libc.math cimport sqrt
|
from libc.math cimport sqrt
|
||||||
from libc.stdint cimport int32_t, uint64_t
|
from libc.stdint cimport int32_t, uint64_t
|
||||||
|
|
||||||
|
import copy
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import numpy
|
import numpy
|
||||||
import numpy.linalg
|
import numpy.linalg
|
||||||
|
@ -24,7 +25,7 @@ from ..attrs cimport LENGTH, POS, LEMMA, TAG, MORPH, DEP, HEAD, SPACY, ENT_IOB
|
||||||
from ..attrs cimport ENT_TYPE, ENT_ID, ENT_KB_ID, SENT_START, IDX, attr_id_t
|
from ..attrs cimport ENT_TYPE, ENT_ID, ENT_KB_ID, SENT_START, IDX, attr_id_t
|
||||||
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
|
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
|
||||||
|
|
||||||
from ..attrs import intify_attrs, IDS
|
from ..attrs import intify_attr, intify_attrs, IDS
|
||||||
from ..util import normalize_slice
|
from ..util import normalize_slice
|
||||||
from ..compat import copy_reg, pickle
|
from ..compat import copy_reg, pickle
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
|
@ -806,7 +807,7 @@ cdef class Doc:
|
||||||
attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
||||||
for id_ in attrs]
|
for id_ in attrs]
|
||||||
if array.dtype != numpy.uint64:
|
if array.dtype != numpy.uint64:
|
||||||
warnings.warn(Warnings.W028.format(type=array.dtype))
|
warnings.warn(Warnings.W101.format(type=array.dtype))
|
||||||
|
|
||||||
if SENT_START in attrs and HEAD in attrs:
|
if SENT_START in attrs and HEAD in attrs:
|
||||||
raise ValueError(Errors.E032)
|
raise ValueError(Errors.E032)
|
||||||
|
@ -882,6 +883,87 @@ cdef class Doc:
|
||||||
set_children_from_heads(self.c, length)
|
set_children_from_heads(self.c, length)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_docs(docs, ensure_whitespace=True, attrs=None):
|
||||||
|
"""Concatenate multiple Doc objects to form a new one. Raises an error if the `Doc` objects do not all share
|
||||||
|
the same `Vocab`.
|
||||||
|
|
||||||
|
docs (list): A list of Doc objects.
|
||||||
|
ensure_whitespace (bool): Insert a space between two adjacent docs whenever the first doc does not end in whitespace.
|
||||||
|
attrs (list): Optional list of attribute ID ints or attribute name strings.
|
||||||
|
RETURNS (Doc): A doc that contains the concatenated docs, or None if no docs were given.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/doc#from_docs
|
||||||
|
"""
|
||||||
|
if not docs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
vocab = {doc.vocab for doc in docs}
|
||||||
|
if len(vocab) > 1:
|
||||||
|
raise ValueError(Errors.E999)
|
||||||
|
(vocab,) = vocab
|
||||||
|
|
||||||
|
if attrs is None:
|
||||||
|
attrs = [LEMMA, NORM]
|
||||||
|
if all(doc.is_nered for doc in docs):
|
||||||
|
attrs.extend([ENT_IOB, ENT_KB_ID, ENT_TYPE])
|
||||||
|
# TODO: separate for is_morphed?
|
||||||
|
if all(doc.is_tagged for doc in docs):
|
||||||
|
attrs.extend([TAG, POS, MORPH])
|
||||||
|
if all(doc.is_parsed for doc in docs):
|
||||||
|
attrs.extend([HEAD, DEP])
|
||||||
|
else:
|
||||||
|
attrs.append(SENT_START)
|
||||||
|
else:
|
||||||
|
if any(isinstance(attr, str) for attr in attrs): # resolve attribute names
|
||||||
|
attrs = [intify_attr(attr) for attr in attrs] # intify_attr returns None for invalid attrs
|
||||||
|
attrs = list(attr for attr in set(attrs) if attr) # filter duplicates, remove None if present
|
||||||
|
if SPACY not in attrs:
|
||||||
|
attrs.append(SPACY)
|
||||||
|
|
||||||
|
concat_words = []
|
||||||
|
concat_spaces = []
|
||||||
|
concat_user_data = {}
|
||||||
|
char_offset = 0
|
||||||
|
for doc in docs:
|
||||||
|
concat_words.extend(t.text for t in doc)
|
||||||
|
concat_spaces.extend(bool(t.whitespace_) for t in doc)
|
||||||
|
|
||||||
|
for key, value in doc.user_data.items():
|
||||||
|
if isinstance(key, tuple) and len(key) == 4:
|
||||||
|
data_type, name, start, end = key
|
||||||
|
if start is not None or end is not None:
|
||||||
|
start += char_offset
|
||||||
|
if end is not None:
|
||||||
|
end += char_offset
|
||||||
|
concat_user_data[(data_type, name, start, end)] = copy.copy(value)
|
||||||
|
else:
|
||||||
|
warnings.warn(Warnings.W101.format(name=name))
|
||||||
|
else:
|
||||||
|
warnings.warn(Warnings.W102.format(key=key, value=value))
|
||||||
|
char_offset += len(doc.text) if not ensure_whitespace or doc[-1].is_space else len(doc.text) + 1
|
||||||
|
|
||||||
|
arrays = [doc.to_array(attrs) for doc in docs]
|
||||||
|
|
||||||
|
if ensure_whitespace:
|
||||||
|
spacy_index = attrs.index(SPACY)
|
||||||
|
for i, array in enumerate(arrays[:-1]):
|
||||||
|
if len(array) > 0 and not docs[i][-1].is_space:
|
||||||
|
array[-1][spacy_index] = 1
|
||||||
|
token_offset = -1
|
||||||
|
for doc in docs[:-1]:
|
||||||
|
token_offset += len(doc)
|
||||||
|
if not doc[-1].is_space:
|
||||||
|
concat_spaces[token_offset] = True
|
||||||
|
|
||||||
|
concat_array = numpy.concatenate(arrays)
|
||||||
|
|
||||||
|
concat_doc = Doc(vocab, words=concat_words, spaces=concat_spaces, user_data=concat_user_data)
|
||||||
|
|
||||||
|
concat_doc.from_array(attrs, concat_array)
|
||||||
|
|
||||||
|
return concat_doc
|
||||||
|
|
||||||
def get_lca_matrix(self):
|
def get_lca_matrix(self):
|
||||||
"""Calculates a matrix of Lowest Common Ancestors (LCA) for a given
|
"""Calculates a matrix of Lowest Common Ancestors (LCA) for a given
|
||||||
`Doc`, where LCA[i, j] is the index of the lowest common ancestor among
|
`Doc`, where LCA[i, j] is the index of the lowest common ancestor among
|
||||||
|
|
|
@ -349,6 +349,33 @@ array of attributes.
|
||||||
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
|
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||||
| **RETURNS** | `Doc` | Itself. |
|
| **RETURNS** | `Doc` | Itself. |
|
||||||
|
|
||||||
|
|
||||||
|
## Doc.from_docs {#from_docs tag="staticmethod"}
|
||||||
|
|
||||||
|
Concatenate multiple `Doc` objects to form a new one. Raises an error if the `Doc` objects do not all share the same `Vocab`.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> from spacy.tokens import Doc
|
||||||
|
> texts = ["London is the capital of the United Kingdom.",
|
||||||
|
> "The River Thames flows through London.",
|
||||||
|
> "The famous Tower Bridge crosses the River Thames."]
|
||||||
|
> docs = list(nlp.pipe(texts))
|
||||||
|
> c_doc = Doc.from_docs(docs)
|
||||||
|
> assert str(c_doc) == " ".join(texts)
|
||||||
|
> assert len(list(c_doc.sents)) == len(docs)
|
||||||
|
> assert [str(ent) for ent in c_doc.ents] == \
|
||||||
|
> [str(ent) for doc in docs for ent in doc.ents]
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ------------------- | ----- | ----------------------------------------------------------------------------------------------- |
|
||||||
|
| `docs` | list | A list of `Doc` objects. |
|
||||||
|
| `ensure_whitespace` | bool | Insert a space between two adjacent docs whenever the first doc does not end in whitespace. |
|
||||||
|
| `attrs` | list | Optional list of attribute ID ints or attribute name strings. |
|
||||||
|
| **RETURNS** | `Doc` | The new `Doc` object that is containing the other docs or `None`, if `docs` is empty or `None`. |
|
||||||
|
|
||||||
## Doc.to_disk {#to_disk tag="method" new="2"}
|
## Doc.to_disk {#to_disk tag="method" new="2"}
|
||||||
|
|
||||||
Save the current state to a directory.
|
Save the current state to a directory.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user