Merge doc.spans in Doc.from_docs() (#7497)

Merge data from `doc.spans` in `Doc.from_docs()`.

* Fix the internal character offset that is set when merging empty docs (this
only affects tokens and spans in `user_data` if an empty doc is in the list
of docs).
Adriane Boyd 2021-03-29 13:34:01 +02:00 committed by GitHub
parent d59f968d08
commit 139f655f34
3 changed files with 40 additions and 2 deletions
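
Before the diffs, a minimal usage sketch of the new behavior (the texts and
the "group" key here are illustrative, not taken from this commit; assumes a
spaCy v3.x build that includes this change):

    import spacy
    from spacy.tokens import Doc

    nlp = spacy.blank("en")
    doc1 = nlp("The first document is here.")
    doc2 = nlp("And the second one follows.")
    # Span groups stored under doc.spans are now carried over by from_docs()
    doc1.spans["group"] = [doc1[1:3]]  # "first document"
    doc2.spans["group"] = [doc2[2:4]]  # "second one"

    merged = Doc.from_docs([doc1, doc2])
    assert "group" in merged.spans
    assert [s.text for s in merged.spans["group"]] == ["first document", "second one"]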

spacy/errors.py

@@ -497,6 +497,9 @@ class Errors:
     E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.")

     # New errors added in v3.x
+    E873 = ("Unable to merge a span from doc.spans with key '{key}' and text "
+            "'{text}'. This is likely a bug in spaCy, so feel free to open an "
+            "issue: https://github.com/explosion/spaCy/issues")
     E874 = ("Could not initialize the tok2vec model from component "
             "'{component}' and layer '{layer}'.")
     E875 = ("To use the PretrainVectors objective, make sure that static vectors are loaded. "

spacy/tests/doc/test_doc_api.py

@@ -352,6 +352,9 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_texts_without_empty = [t for t in en_texts if len(t)]
     de_text = "Wie war die Frage?"
     en_docs = [en_tokenizer(text) for text in en_texts]
+    en_docs[0].spans["group"] = [en_docs[0][1:4]]
+    en_docs[2].spans["group"] = [en_docs[2][1:4]]
+    span_group_texts = sorted([en_docs[0][1:4].text, en_docs[2][1:4].text])
     docs_idx = en_texts[0].index("docs")
     de_doc = de_tokenizer(de_text)
     expected = (True, None, None, None)
@@ -377,6 +380,8 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     with pytest.raises(AttributeError):  # not callable, because it was not set via set_extension
         m_doc[2]._.is_ambiguous
     assert len(m_doc.user_data) == len(en_docs[0].user_data)  # but it's there
+    assert "group" in m_doc.spans
+    assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])

     m_doc = Doc.from_docs(en_docs, ensure_whitespace=False)
     assert len(en_texts_without_empty) == len(list(m_doc.sents))
@@ -388,6 +393,8 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     assert len(m_doc) == len(en_docs_tokens)
     think_idx = len(en_texts[0]) + 0 + en_texts[2].index("think")
     assert m_doc[9].idx == think_idx
+    assert "group" in m_doc.spans
+    assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])

     m_doc = Doc.from_docs(en_docs, attrs=["lemma", "length", "pos"])
     assert len(str(m_doc)) > len(en_texts[0]) + len(en_texts[1])
@@ -399,6 +406,8 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     assert len(m_doc) == len(en_docs_tokens)
     think_idx = len(en_texts[0]) + 1 + en_texts[2].index("think")
     assert m_doc[9].idx == think_idx
+    assert "group" in m_doc.spans
+    assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])


 def test_doc_api_from_docs_ents(en_tokenizer):

spacy/tokens/doc.pyx

@@ -6,7 +6,7 @@ from libc.math cimport sqrt
 from libc.stdint cimport int32_t, uint64_t

 import copy
-from collections import Counter
+from collections import Counter, defaultdict
 from enum import Enum
 import itertools
 import numpy
@@ -1120,6 +1120,7 @@ cdef class Doc:
         concat_words = []
         concat_spaces = []
         concat_user_data = {}
+        concat_spans = defaultdict(list)
         char_offset = 0
         for doc in docs:
             concat_words.extend(t.text for t in doc)
@@ -1137,8 +1138,17 @@ cdef class Doc:
                         warnings.warn(Warnings.W101.format(name=name))
                 else:
                     warnings.warn(Warnings.W102.format(key=key, value=value))
+            for key in doc.spans:
+                for span in doc.spans[key]:
+                    concat_spans[key].append((
+                        span.start_char + char_offset,
+                        span.end_char + char_offset,
+                        span.label,
+                        span.kb_id,
+                        span.text,  # included as a check
+                    ))
             char_offset += len(doc.text)
-            if ensure_whitespace and not (len(doc) > 0 and doc[-1].is_space):
+            if len(doc) > 0 and ensure_whitespace and not doc[-1].is_space:
                 char_offset += 1

         arrays = [doc.to_array(attrs) for doc in docs]
@@ -1160,6 +1170,22 @@ cdef class Doc:
         concat_doc.from_array(attrs, concat_array)
+        for key in concat_spans:
+            if key not in concat_doc.spans:
+                concat_doc.spans[key] = []
+            for span_tuple in concat_spans[key]:
+                span = concat_doc.char_span(
+                    span_tuple[0],
+                    span_tuple[1],
+                    label=span_tuple[2],
+                    kb_id=span_tuple[3],
+                )
+                text = span_tuple[4]
+                if span is not None and span.text == text:
+                    concat_doc.spans[key].append(span)
+                else:
+                    raise ValueError(Errors.E873.format(key=key, text=text))
         return concat_doc

     def get_lca_matrix(self):
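
Both halves of the change come down to the same offset bookkeeping: record
each span as character offsets shifted into the coordinate space of the
concatenated text, then rebuild it with char_span() and verify the round-trip.
A self-contained sketch of that technique in plain Python (merge_span_offsets
is a hypothetical helper, and text.endswith(" ") stands in for the real
doc[-1].is_space check):

    from collections import defaultdict

    def merge_span_offsets(texts, spans_per_text, ensure_whitespace=True):
        """Shift per-text character spans into the coordinate space of the
        concatenated text, mirroring the bookkeeping in Doc.from_docs()."""
        merged = defaultdict(list)
        char_offset = 0
        for text, spans in zip(texts, spans_per_text):
            for key, (start, end) in spans:
                merged[key].append((start + char_offset, end + char_offset))
            char_offset += len(text)
            # The fix from this commit: only pad for the joining space when
            # the doc is non-empty; otherwise every span in every following
            # doc would be shifted one character too far.
            if len(text) > 0 and ensure_whitespace and not text.endswith(" "):
                char_offset += 1
        return merged

    texts = ["Hello world", "", "Goodbye"]
    spans = [[("group", (0, 5))], [], [("group", (0, 7))]]
    # "Hello world" + " " + "" + "Goodbye": "Goodbye" starts at 12, not 13
    assert merge_span_offsets(texts, spans)["group"] == [(0, 5), (12, 19)]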