Fix/update extension copying in Span.as_doc and Doc.from_docs (#7574)

* Adjust custom extension data when copying user data in `Span.as_doc()`
* Restrict `Doc.from_docs()` to adjusting offsets for custom extension data
  * Update test to use extension
  * (Duplicate bug fix for character offset from #7497)
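
As an illustration of the user-facing behavior this fixes: underscore extension data is stored in `doc.user_data` under keys of the form `("._.", name, start_char, end_char)`, so copying the dict verbatim leaves the offsets pointing into the original doc's text. A minimal sketch, not part of the diff; the extension name and token choice mirror the tests below:

    import spacy
    from spacy.tokens import Token

    nlp = spacy.blank("en")
    Token.set_extension("is_x", default=False)

    doc = nlp("The quick brown fox jumps over the lazy dog")
    doc[7]._.is_x = True  # flag "lazy"

    # as_doc(copy_user_data=True) now shifts each extension key by
    # span.start_char so it points at the right characters in the copy.
    span_doc = doc[4:9].as_doc(copy_user_data=True)
    assert span_doc[3]._.is_x  # "lazy" keeps its flag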
Adriane Boyd 2021-03-30 09:49:12 +02:00 committed by GitHub
parent af07fc3bc1
commit 27a48f2802
4 changed files with 37 additions and 11 deletions

spacy/tests/doc/test_doc_api.py

@@ -6,12 +6,14 @@ import logging
 import mock

 from spacy.lang.xx import MultiLanguage
-from spacy.tokens import Doc, Span
+from spacy.tokens import Doc, Span, Token
 from spacy.vocab import Vocab
 from spacy.lexeme import Lexeme
 from spacy.lang.en import English
 from spacy.attrs import ENT_TYPE, ENT_IOB, SENT_START, HEAD, DEP, MORPH
+from .test_underscore import clean_underscore  # noqa: F401


 def test_doc_api_init(en_vocab):
     words = ["a", "b", "c", "d"]
@@ -347,6 +349,7 @@ def test_doc_from_array_morph(en_vocab):
     assert [str(t.morph) for t in doc] == [str(t.morph) for t in new_doc]


+@pytest.mark.usefixtures("clean_underscore")
 def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_texts = ["Merging the docs is fun.", "", "They don't think alike."]
     en_texts_without_empty = [t for t in en_texts if len(t)]
@@ -355,10 +358,10 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_docs[0].spans["group"] = [en_docs[0][1:4]]
     en_docs[2].spans["group"] = [en_docs[2][1:4]]
     span_group_texts = sorted([en_docs[0][1:4].text, en_docs[2][1:4].text])
-    docs_idx = en_texts[0].index("docs")
     de_doc = de_tokenizer(de_text)
-    expected = (True, None, None, None)
-    en_docs[0].user_data[("._.", "is_ambiguous", docs_idx, None)] = expected
+    Token.set_extension("is_ambiguous", default=False)
+    en_docs[0][2]._.is_ambiguous = True  # docs
+    en_docs[2][3]._.is_ambiguous = True  # think
     assert Doc.from_docs([]) is None
     assert de_doc is not Doc.from_docs([de_doc])
     assert str(de_doc) == str(Doc.from_docs([de_doc]))
@@ -375,11 +378,10 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_docs_tokens = [t for doc in en_docs for t in doc]
     assert len(m_doc) == len(en_docs_tokens)
     think_idx = len(en_texts[0]) + 1 + en_texts[2].index("think")
+    assert m_doc[2]._.is_ambiguous == True
     assert m_doc[9].idx == think_idx
-    with pytest.raises(AttributeError):
-        # not callable, because it was not set via set_extension
-        m_doc[2]._.is_ambiguous
-    assert len(m_doc.user_data) == len(en_docs[0].user_data)  # but it's there
+    assert m_doc[9]._.is_ambiguous == True
+    assert not any([t._.is_ambiguous for t in m_doc[3:8]])
     assert "group" in m_doc.spans
     assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])
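
The merged-doc case works the same way in the other direction; a standalone sketch of what the updated test exercises, using the same texts minus the empty string and the German doc:

    import spacy
    from spacy.tokens import Doc, Token

    nlp = spacy.blank("en")
    Token.set_extension("is_ambiguous", default=False)

    doc1 = nlp("Merging the docs is fun.")
    doc2 = nlp("They don't think alike.")
    doc1[2]._.is_ambiguous = True  # "docs"
    doc2[3]._.is_ambiguous = True  # "think"

    # Doc.from_docs() shifts each extension key forward by the doc's
    # character offset in the concatenated text, so the flags follow
    # their tokens into the merged doc.
    m_doc = Doc.from_docs([doc1, doc2])
    assert m_doc[2]._.is_ambiguous  # "docs"
    assert m_doc[9]._.is_ambiguous  # "think"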

spacy/tests/doc/test_span.py

@@ -1,9 +1,11 @@
 import pytest
 from spacy.attrs import ORTH, LENGTH
-from spacy.tokens import Doc, Span
+from spacy.tokens import Doc, Span, Token
 from spacy.vocab import Vocab
 from spacy.util import filter_spans
+from .test_underscore import clean_underscore  # noqa: F401


 @pytest.fixture
 def doc(en_tokenizer):
@@ -219,11 +221,14 @@ def test_span_as_doc(doc):
     assert span_doc[0].idx == 0


+@pytest.mark.usefixtures("clean_underscore")
 def test_span_as_doc_user_data(doc):
     """Test that the user_data can be preserved (but not by default). """
     my_key = "my_info"
     my_value = 342
     doc.user_data[my_key] = my_value
+    Token.set_extension("is_x", default=False)
+    doc[7]._.is_x = True
     span = doc[4:10]

     span_doc_with = span.as_doc(copy_user_data=True)
@@ -232,6 +237,12 @@ def test_span_as_doc_user_data(doc):
     assert doc.user_data.get(my_key, None) is my_value
     assert span_doc_with.user_data.get(my_key, None) is my_value
     assert span_doc_without.user_data.get(my_key, None) is None
+    for i in range(len(span_doc_with)):
+        if i != 3:
+            assert span_doc_with[i]._.is_x is False
+        else:
+            assert span_doc_with[i]._.is_x is True
+    assert not any([t._.is_x for t in span_doc_without])


 def test_span_string_label_kb_id(doc):

spacy/tokens/doc.pyx

@@ -1127,7 +1127,7 @@ cdef class Doc:
             concat_spaces.extend(bool(t.whitespace_) for t in doc)

             for key, value in doc.user_data.items():
-                if isinstance(key, tuple) and len(key) == 4:
+                if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
                     data_type, name, start, end = key
                     if start is not None or end is not None:
                         start += char_offset

spacy/tokens/span.pyx

@@ -6,6 +6,7 @@ from libc.math cimport sqrt
 import numpy
 from thinc.api import get_array_module
 import warnings
+import copy

 from .doc cimport token_by_start, token_by_end, get_token_attr, _get_lca_matrix
 from ..structs cimport TokenC, LexemeC
@@ -241,7 +242,19 @@ cdef class Span:
                 if cat_start == self.start_char and cat_end == self.end_char:
                     doc.cats[cat_label] = value
         if copy_user_data:
-            doc.user_data = self.doc.user_data
+            user_data = {}
+            char_offset = self.start_char
+            for key, value in self.doc.user_data.items():
+                if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
+                    data_type, name, start, end = key
+                    if start is not None or end is not None:
+                        start -= char_offset
+                        if end is not None:
+                            end -= char_offset
+                    user_data[(data_type, name, start, end)] = copy.copy(value)
+                else:
+                    user_data[key] = copy.copy(value)
+            doc.user_data = user_data
         return doc

     def _fix_dep_copy(self, attrs, array):
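
The offset arithmetic above is easy to sanity-check in isolation. A sketch with a hypothetical `remap_user_data` helper that mirrors the new loop in `as_doc()` (plain Python, not the Cython implementation):

    import copy

    def remap_user_data(user_data, char_offset):
        # Shift ("._.", name, start, end) keys into the new doc's
        # coordinates; copy all other entries through unchanged.
        remapped = {}
        for key, value in user_data.items():
            if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
                data_type, name, start, end = key
                if start is not None or end is not None:
                    start -= char_offset
                    if end is not None:
                        end -= char_offset
                remapped[(data_type, name, start, end)] = copy.copy(value)
            else:
                remapped[key] = copy.copy(value)
        return remapped

    # A token extension stored at character 35 of the original doc lands
    # at character 15 of a span starting at character 20; plain user keys
    # pass through untouched:
    old = {("._.", "is_x", 35, None): True, "my_info": 342}
    assert remap_user_data(old, 20) == {("._.", "is_x", 15, None): True, "my_info": 342}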