Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 07:57:35 +03:00)
	Fix/update extension copying in Span.as_doc and Doc.from_docs (#7574)
* Adjust custom extension data when copying user data in `Span.as_doc()`
* Restrict `Doc.from_docs()` to adjusting offsets for custom extension data
* Update test to use extension
* (Duplicate bug fix for character offset from #7497)
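Editorial note, for context (not part of the commit): spaCy stores custom extension values in `Doc.user_data`, keyed by 4-tuples of the form `("._.", name, start_char, end_char)` — `(token.idx, None)` for token extensions, `(start_char, end_char)` for span extensions, `(None, None)` for doc extensions. A minimal sketch of the layout this commit's offset bookkeeping operates on, assuming a spaCy v3 build that includes the fix (texts mirror the test data below):

import spacy
from spacy.tokens import Doc, Token

nlp = spacy.blank("en")
Token.set_extension("is_ambiguous", default=False)

doc1 = nlp("Merging the docs is fun.")
doc2 = nlp("They don't think alike.")
doc1[2]._.is_ambiguous = True  # "docs", which starts at character 12

# The flag lives in Doc.user_data under a 4-tuple key:
#   ('._.', 'is_ambiguous', 12, None)
print([key for key in doc1.user_data if isinstance(key, tuple)])

# Doc.from_docs must re-key entries from later docs by their character
# offset in the concatenated text -- exactly the bookkeeping fixed here.
merged = Doc.from_docs([doc1, doc2])
assert merged[2]._.is_ambiguous is True  # "docs" again, offset unchanged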
This commit is contained in:
parent af07fc3bc1
commit 27a48f2802
spacy/tests/doc/test_doc_api.py
@@ -6,12 +6,14 @@ import logging
 import mock
 
 from spacy.lang.xx import MultiLanguage
-from spacy.tokens import Doc, Span
+from spacy.tokens import Doc, Span, Token
 from spacy.vocab import Vocab
 from spacy.lexeme import Lexeme
 from spacy.lang.en import English
 from spacy.attrs import ENT_TYPE, ENT_IOB, SENT_START, HEAD, DEP, MORPH
 
+from .test_underscore import clean_underscore  # noqa: F401
+
 
 def test_doc_api_init(en_vocab):
     words = ["a", "b", "c", "d"]
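Aside: `clean_underscore` is imported purely for its side effect of registering a pytest fixture (hence the `noqa: F401`) and is activated per test via `@pytest.mark.usefixtures`. From memory of `test_underscore.py` — treat this as a rough sketch, not the verbatim source — it resets the `Underscore` registries after each test so `set_extension` calls cannot leak between tests:

import pytest
from spacy.tokens.underscore import Underscore

@pytest.fixture(scope="function")
def clean_underscore():
    # run the test, then drop all registered extensions
    yield
    Underscore.doc_extensions = {}
    Underscore.span_extensions = {}
    Underscore.token_extensions = {}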
@@ -347,6 +349,7 @@ def test_doc_from_array_morph(en_vocab):
     assert [str(t.morph) for t in doc] == [str(t.morph) for t in new_doc]
 
 
+@pytest.mark.usefixtures("clean_underscore")
 def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_texts = ["Merging the docs is fun.", "", "They don't think alike."]
     en_texts_without_empty = [t for t in en_texts if len(t)]
@@ -355,10 +358,10 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_docs[0].spans["group"] = [en_docs[0][1:4]]
     en_docs[2].spans["group"] = [en_docs[2][1:4]]
     span_group_texts = sorted([en_docs[0][1:4].text, en_docs[2][1:4].text])
-    docs_idx = en_texts[0].index("docs")
     de_doc = de_tokenizer(de_text)
-    expected = (True, None, None, None)
-    en_docs[0].user_data[("._.", "is_ambiguous", docs_idx, None)] = expected
+    Token.set_extension("is_ambiguous", default=False)
+    en_docs[0][2]._.is_ambiguous = True  # docs
+    en_docs[2][3]._.is_ambiguous = True  # think
     assert Doc.from_docs([]) is None
     assert de_doc is not Doc.from_docs([de_doc])
     assert str(de_doc) == str(Doc.from_docs([de_doc]))
@@ -375,11 +378,10 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     en_docs_tokens = [t for doc in en_docs for t in doc]
     assert len(m_doc) == len(en_docs_tokens)
     think_idx = len(en_texts[0]) + 1 + en_texts[2].index("think")
+    assert m_doc[2]._.is_ambiguous == True
     assert m_doc[9].idx == think_idx
-    with pytest.raises(AttributeError):
-        # not callable, because it was not set via set_extension
-        m_doc[2]._.is_ambiguous
-    assert len(m_doc.user_data) == len(en_docs[0].user_data)  # but it's there
+    assert m_doc[9]._.is_ambiguous == True
+    assert not any([t._.is_ambiguous for t in m_doc[3:8]])
     assert "group" in m_doc.spans
     assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])
 
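The index arithmetic in this hunk is compact; a sketch (outside the commit) of why `think_idx` and `m_doc[9]` line up:

en_texts = ["Merging the docs is fun.", "", "They don't think alike."]
# The empty string contributes no tokens, and from_docs (with the default
# ensure_whitespace=True) joins the remaining texts with one space, so:
think_idx = len(en_texts[0]) + 1 + en_texts[2].index("think")
assert think_idx == 24 + 1 + 11 == 36
# Token-wise: the first doc has 6 tokens (m_doc[0:6]); the second doc's
# "They do n't think alike ." become m_doc[6:12], so "think" is m_doc[9].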
spacy/tests/doc/test_span.py
@@ -1,9 +1,11 @@
 import pytest
 from spacy.attrs import ORTH, LENGTH
-from spacy.tokens import Doc, Span
+from spacy.tokens import Doc, Span, Token
 from spacy.vocab import Vocab
 from spacy.util import filter_spans
 
+from .test_underscore import clean_underscore  # noqa: F401
+
 
 @pytest.fixture
 def doc(en_tokenizer):
@@ -219,11 +221,14 @@ def test_span_as_doc(doc):
     assert span_doc[0].idx == 0
 
 
+@pytest.mark.usefixtures("clean_underscore")
 def test_span_as_doc_user_data(doc):
     """Test that the user_data can be preserved (but not by default)."""
     my_key = "my_info"
     my_value = 342
     doc.user_data[my_key] = my_value
+    Token.set_extension("is_x", default=False)
+    doc[7]._.is_x = True
 
     span = doc[4:10]
     span_doc_with = span.as_doc(copy_user_data=True)
@@ -232,6 +237,12 @@ def test_span_as_doc_user_data(doc):
     assert doc.user_data.get(my_key, None) is my_value
     assert span_doc_with.user_data.get(my_key, None) is my_value
     assert span_doc_without.user_data.get(my_key, None) is None
+    for i in range(len(span_doc_with)):
+        if i != 3:
+            assert span_doc_with[i]._.is_x is False
+        else:
+            assert span_doc_with[i]._.is_x is True
+    assert not any([t._.is_x for t in span_doc_without])
 
 
 def test_span_string_label_kb_id(doc):
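Why index 3: `doc[7]` is the fourth token of `doc[4:10]`, so after `as_doc()` it sits at position 3, and its character-offset key is re-based by `span.start_char`. A standalone sketch (outside the commit) with an illustrative text in place of this file's `doc` fixture, assuming a spaCy build that includes the fix:

import spacy
from spacy.tokens import Token

nlp = spacy.blank("en")
Token.set_extension("is_x", default=False)

doc = nlp("One two three four five six seven eight nine ten eleven twelve")
doc[7]._.is_x = True  # "eight"

span = doc[4:10]
span_doc = span.as_doc(copy_user_data=True)

# The key is re-based from ("._.", "is_x", doc[7].idx, None) to
# ("._.", "is_x", doc[7].idx - span.start_char, None).
assert span_doc[3]._.is_x is True
assert not any(t._.is_x for t in span.as_doc())  # default: user_data not copied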
spacy/tokens/doc.pyx
@@ -1127,7 +1127,7 @@ cdef class Doc:
             concat_spaces.extend(bool(t.whitespace_) for t in doc)
 
             for key, value in doc.user_data.items():
-                if isinstance(key, tuple) and len(key) == 4:
+                if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
                     data_type, name, start, end = key
                     if start is not None or end is not None:
                         start += char_offset
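The new `key[0] == "._."` check matters because `user_data` is an ordinary dict that users may fill with arbitrary keys; under the old `len(key) == 4` test, `Doc.from_docs` would have shifted the last two elements of any 4-tuple key, extension data or not. A minimal sketch of the predicate the new condition encodes (the `"my_app"` key is hypothetical, for illustration only):

def is_extension_key(key):
    # Only keys written through the ._. API carry character offsets
    # that must be shifted when docs are concatenated.
    return isinstance(key, tuple) and len(key) == 4 and key[0] == "._."

assert is_extension_key(("._.", "is_ambiguous", 12, None))
# A user-supplied key that merely looks similar must pass through
# untouched; the old test would have shifted its 0 and 5.
assert not is_extension_key(("my_app", "annotator", 0, 5))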
spacy/tokens/span.pyx
@@ -6,6 +6,7 @@ from libc.math cimport sqrt
 import numpy
 from thinc.api import get_array_module
 import warnings
+import copy
 
 from .doc cimport token_by_start, token_by_end, get_token_attr, _get_lca_matrix
 from ..structs cimport TokenC, LexemeC
@@ -241,7 +242,19 @@ cdef class Span:
                 if cat_start == self.start_char and cat_end == self.end_char:
                     doc.cats[cat_label] = value
         if copy_user_data:
-            doc.user_data = self.doc.user_data
+            user_data = {}
+            char_offset = self.start_char
+            for key, value in self.doc.user_data.items():
+                if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
+                    data_type, name, start, end = key
+                    if start is not None or end is not None:
+                        start -= char_offset
+                        if end is not None:
+                            end -= char_offset
+                        user_data[(data_type, name, start, end)] = copy.copy(value)
+                else:
+                    user_data[key] = copy.copy(value)
+            doc.user_data = user_data
         return doc
 
     def _fix_dep_copy(self, attrs, array):
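Two things this replacement changes: the new `Doc` no longer shares the original's `user_data` dict (values are shallow-copied with `copy.copy`), and extension keys are re-based by subtracting `span.start_char`. Note also that, as committed, a doc-level extension key (`start` and `end` both `None`) matches the outer condition but neither inner branch, so it is not copied. A pure-Python sketch of the same loop (illustrative, not spaCy API):

import copy

def rebase_user_data(user_data, char_offset):
    """Re-key extension entries relative to a span's start character."""
    out = {}
    for key, value in user_data.items():
        if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
            data_type, name, start, end = key
            if start is not None or end is not None:
                start -= char_offset
                if end is not None:
                    end -= char_offset
                out[(data_type, name, start, end)] = copy.copy(value)
            # doc-level keys (start is None and end is None) fall through
            # both branches here and are dropped, mirroring the commit
        else:
            out[key] = copy.copy(value)
    return out

# Token flag at character 34, span starting at character 19:
print(rebase_user_data({("._.", "is_x", 34, None): True}, 19))
# -> {('._.', 'is_x', 15, None): True}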