Fix/update extension copying in Span.as_doc and Doc.from_docs (#7574)

* Adjust custom extension data when copying user data in `Span.as_doc()`
* Restrict `Doc.from_docs()` to adjusting offsets for custom extension data
  * Update test to use extension
  * (Duplicate bug fix for character offset from #7497)
Author: Adriane Boyd
Date: 2021-03-30 09:49:12 +02:00 (committed by GitHub)
Parent: af07fc3bc1
Commit: 27a48f2802
4 changed files with 37 additions and 11 deletions
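
Both fixes hinge on spaCy's internal convention for storing custom extension (underscore) data: token- and span-level values are kept in `doc.user_data` under 4-tuple keys of the form `("._.", name, start_char, end_char)`. A minimal sketch of what those keys look like (illustrative extension name; the key layout is an internal detail, shown here only to make the offset shifting in the hunks below concrete):

```python
# Sketch: how underscore extension values land in doc.user_data.
import spacy
from spacy.tokens import Token

Token.set_extension("is_ambiguous", default=False)

nlp = spacy.blank("en")
doc = nlp("Merging the docs is fun.")
doc[2]._.is_ambiguous = True  # token "docs", char offset 12

print(list(doc.user_data.keys()))
# [('._.', 'is_ambiguous', 12, None)]  # end is None for token-level data
```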

spacy/tests/doc/test_doc_api.py

@@ -6,12 +6,14 @@ import logging
 import mock

 from spacy.lang.xx import MultiLanguage
-from spacy.tokens import Doc, Span
+from spacy.tokens import Doc, Span, Token
 from spacy.vocab import Vocab
 from spacy.lexeme import Lexeme
 from spacy.lang.en import English
 from spacy.attrs import ENT_TYPE, ENT_IOB, SENT_START, HEAD, DEP, MORPH
+from .test_underscore import clean_underscore  # noqa: F401


 def test_doc_api_init(en_vocab):
     words = ["a", "b", "c", "d"]
@@ -347,6 +349,7 @@ def test_doc_from_array_morph(en_vocab):
     assert [str(t.morph) for t in doc] == [str(t.morph) for t in new_doc]


+@pytest.mark.usefixtures("clean_underscore")
 def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_texts = ["Merging the docs is fun.", "", "They don't think alike."]
     en_texts_without_empty = [t for t in en_texts if len(t)]
@@ -355,10 +358,10 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_docs[0].spans["group"] = [en_docs[0][1:4]]
     en_docs[2].spans["group"] = [en_docs[2][1:4]]
     span_group_texts = sorted([en_docs[0][1:4].text, en_docs[2][1:4].text])
-    docs_idx = en_texts[0].index("docs")
     de_doc = de_tokenizer(de_text)
-    expected = (True, None, None, None)
-    en_docs[0].user_data[("._.", "is_ambiguous", docs_idx, None)] = expected
+    Token.set_extension("is_ambiguous", default=False)
+    en_docs[0][2]._.is_ambiguous = True  # docs
+    en_docs[2][3]._.is_ambiguous = True  # think
     assert Doc.from_docs([]) is None
     assert de_doc is not Doc.from_docs([de_doc])
     assert str(de_doc) == str(Doc.from_docs([de_doc]))
@@ -375,11 +378,10 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_docs_tokens = [t for doc in en_docs for t in doc]
     assert len(m_doc) == len(en_docs_tokens)
     think_idx = len(en_texts[0]) + 1 + en_texts[2].index("think")
+    assert m_doc[2]._.is_ambiguous == True
     assert m_doc[9].idx == think_idx
-    with pytest.raises(AttributeError):
-        # not callable, because it was not set via set_extension
-        m_doc[2]._.is_ambiguous
-    assert len(m_doc.user_data) == len(en_docs[0].user_data)  # but it's there
+    assert m_doc[9]._.is_ambiguous == True
+    assert not any([t._.is_ambiguous for t in m_doc[3:8]])
     assert "group" in m_doc.spans
     assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])
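
For context, here is a hedged sketch of the behavior the rewritten test pins down (illustrative texts; assumes the `doc.pyx` fix below is in place): after `Doc.from_docs`, extension values set on the source docs resolve at each token's new character offset in the merged doc.

```python
# Sketch: token extensions survive Doc.from_docs with re-based offsets.
import spacy
from spacy.tokens import Doc, Token

Token.set_extension("is_ambiguous", default=False, force=True)

nlp = spacy.blank("en")
docs = [nlp("Merging the docs is fun."), nlp("They don't think alike.")]
docs[0][2]._.is_ambiguous = True  # "docs"
docs[1][3]._.is_ambiguous = True  # "think"

m_doc = Doc.from_docs(docs)
assert [t.text for t in m_doc if t._.is_ambiguous] == ["docs", "think"]
```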

spacy/tests/doc/test_span.py

@@ -1,9 +1,11 @@
 import pytest
 from spacy.attrs import ORTH, LENGTH
-from spacy.tokens import Doc, Span
+from spacy.tokens import Doc, Span, Token
 from spacy.vocab import Vocab
 from spacy.util import filter_spans
+from .test_underscore import clean_underscore  # noqa: F401


 @pytest.fixture
 def doc(en_tokenizer):
@@ -219,11 +221,14 @@ def test_span_as_doc(doc):
     assert span_doc[0].idx == 0


+@pytest.mark.usefixtures("clean_underscore")
 def test_span_as_doc_user_data(doc):
     """Test that the user_data can be preserved (but not by default). """
     my_key = "my_info"
     my_value = 342
     doc.user_data[my_key] = my_value
+    Token.set_extension("is_x", default=False)
+    doc[7]._.is_x = True
     span = doc[4:10]

     span_doc_with = span.as_doc(copy_user_data=True)
@@ -232,6 +237,12 @@
     assert doc.user_data.get(my_key, None) is my_value
     assert span_doc_with.user_data.get(my_key, None) is my_value
     assert span_doc_without.user_data.get(my_key, None) is None
+    for i in range(len(span_doc_with)):
+        if i != 3:
+            assert span_doc_with[i]._.is_x is False
+        else:
+            assert span_doc_with[i]._.is_x is True
+    assert not any([t._.is_x for t in span_doc_without])


 def test_span_string_label_kb_id(doc):
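
Similarly, a small usage sketch of what the new assertions check (hypothetical text and extension name): with `copy_user_data=True`, extension values on tokens inside the span carry over to the standalone doc at their re-based positions, while values on tokens outside the span are dropped along with those tokens.

```python
# Sketch: Span.as_doc(copy_user_data=True) keeps in-span extension data.
import spacy
from spacy.tokens import Token

Token.set_extension("is_x", default=False, force=True)

nlp = spacy.blank("en")
doc = nlp("One two three four five six seven eight nine ten eleven")
doc[7]._.is_x = True  # "eight"

span_doc = doc[4:10].as_doc(copy_user_data=True)
assert [t.text for t in span_doc if t._.is_x] == ["eight"]  # index 3 in span_doc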

spacy/tokens/doc.pyx

@@ -1127,7 +1127,7 @@ cdef class Doc:
             concat_spaces.extend(bool(t.whitespace_) for t in doc)

             for key, value in doc.user_data.items():
-                if isinstance(key, tuple) and len(key) == 4:
+                if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
                     data_type, name, start, end = key
                     if start is not None or end is not None:
                         start += char_offset
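
The added `key[0] == "._."` guard matters because `user_data` can hold arbitrary user keys: previously any 4-tuple key had its offsets shifted, which is also why the old test could fake an extension by writing a raw `("._.", "is_ambiguous", docs_idx, None)` key by hand. A standalone sketch of the tightened filter (plain Python for illustration, not the actual Cython internals):

```python
# Sketch of the tightened key filter: only underscore-style keys are
# re-offset; any other tuple key passes through unchanged.
def shift_key(key, char_offset):
    if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
        data_type, name, start, end = key
        if start is not None or end is not None:
            start += char_offset
            if end is not None:
                end += char_offset
        return (data_type, name, start, end)
    return key  # e.g. a user's own ("a", "b", 1, 2) key is left alone


assert shift_key(("._.", "is_x", 3, None), 10) == ("._.", "is_x", 13, None)
assert shift_key(("a", "b", 1, 2), 10) == ("a", "b", 1, 2)
```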

spacy/tokens/span.pyx

@@ -6,6 +6,7 @@ from libc.math cimport sqrt
 import numpy
 from thinc.api import get_array_module
 import warnings
+import copy

 from .doc cimport token_by_start, token_by_end, get_token_attr, _get_lca_matrix
 from ..structs cimport TokenC, LexemeC
@@ -241,7 +242,19 @@ cdef class Span:
                 if cat_start == self.start_char and cat_end == self.end_char:
                     doc.cats[cat_label] = value
         if copy_user_data:
-            doc.user_data = self.doc.user_data
+            user_data = {}
+            char_offset = self.start_char
+            for key, value in self.doc.user_data.items():
+                if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
+                    data_type, name, start, end = key
+                    if start is not None or end is not None:
+                        start -= char_offset
+                        if end is not None:
+                            end -= char_offset
+                    user_data[(data_type, name, start, end)] = copy.copy(value)
+                else:
+                    user_data[key] = copy.copy(value)
+            doc.user_data = user_data
         return doc

     def _fix_dep_copy(self, attrs, array):
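
Two details of the rewritten `as_doc` branch are worth noting: offsets are shifted by `-self.start_char` so keys are re-based to the span doc's own character coordinates, and values go through `copy.copy`, so the span doc no longer aliases the parent's `user_data` entries wholesale. A plain-Python sketch of the re-basing and copy semantics (hypothetical data, mirroring the loop above):

```python
# Sketch: re-base offsets and shallow-copy values, so parent and span
# docs stop sharing user_data entries.
import copy

parent_user_data = {("._.", "note", 16, None): ["seen"]}
char_offset = 16  # hypothetical span.start_char

span_user_data = {}
for (data_type, name, start, end), value in parent_user_data.items():
    if start is not None or end is not None:
        start -= char_offset
        if end is not None:
            end -= char_offset
    span_user_data[(data_type, name, start, end)] = copy.copy(value)

assert ("._.", "note", 0, None) in span_user_data
# shallow copy: the span doc gets a new list object, not the parent's
assert span_user_data[("._.", "note", 0, None)] is not parent_user_data[("._.", "note", 16, None)]
```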