Fix spans weak ref in doc copy (#7225)

* failing unit test

* ensure that doc.spans refers to the copied doc, not the old

* add type info
This commit is contained in:
Sofie Van Landeghem 2021-02-28 02:32:48 +01:00 committed by GitHub
parent 9f204b354b
commit dd99872bb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 3 deletions

View File

@ -1,3 +1,5 @@
import weakref
import pytest import pytest
import numpy import numpy
import logging import logging
@ -663,3 +665,10 @@ def test_span_groups(en_tokenizer):
assert doc.spans["hi"].has_overlap assert doc.spans["hi"].has_overlap
del doc.spans["hi"] del doc.spans["hi"]
assert "hi" not in doc.spans assert "hi" not in doc.spans
def test_doc_spans_copy(en_tokenizer):
doc1 = en_tokenizer("Some text about Colombia and the Czech Republic")
assert weakref.ref(doc1) == doc1.spans.doc_ref
doc2 = doc1.copy()
assert weakref.ref(doc2) == doc2.spans.doc_ref

View File

@ -33,8 +33,10 @@ class SpanGroups(UserDict):
def _make_span_group(self, name: str, spans: Iterable["Span"]) -> SpanGroup: def _make_span_group(self, name: str, spans: Iterable["Span"]) -> SpanGroup:
return SpanGroup(self.doc_ref(), name=name, spans=spans) return SpanGroup(self.doc_ref(), name=name, spans=spans)
def copy(self) -> "SpanGroups": def copy(self, doc: "Doc" = None) -> "SpanGroups":
return SpanGroups(self.doc_ref()).from_bytes(self.to_bytes()) if doc is None:
doc = self.doc_ref()
return SpanGroups(doc).from_bytes(self.to_bytes())
def to_bytes(self) -> bytes: def to_bytes(self) -> bytes:
# We don't need to serialize this as a dict, because the groups # We don't need to serialize this as a dict, because the groups

View File

@ -1188,7 +1188,7 @@ cdef class Doc:
other.user_span_hooks = dict(self.user_span_hooks) other.user_span_hooks = dict(self.user_span_hooks)
other.length = self.length other.length = self.length
other.max_length = self.max_length other.max_length = self.max_length
other.spans = self.spans.copy() other.spans = self.spans.copy(doc=other)
buff_size = other.max_length + (PADDING*2) buff_size = other.max_length + (PADDING*2)
assert buff_size > 0 assert buff_size > 0
tokens = <TokenC*>other.mem.alloc(buff_size, sizeof(TokenC)) tokens = <TokenC*>other.mem.alloc(buff_size, sizeof(TokenC))