mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix spans weak ref in doc copy (#7225)
* failing unit test * ensure that doc.spans refers to the copied doc, not the old * add type info
This commit is contained in:
parent
9f204b354b
commit
dd99872bb0
|
@ -1,3 +1,5 @@
|
|||
import weakref
|
||||
|
||||
import pytest
|
||||
import numpy
|
||||
import logging
|
||||
|
@ -663,3 +665,10 @@ def test_span_groups(en_tokenizer):
|
|||
assert doc.spans["hi"].has_overlap
|
||||
del doc.spans["hi"]
|
||||
assert "hi" not in doc.spans
|
||||
|
||||
|
||||
def test_doc_spans_copy(en_tokenizer):
|
||||
doc1 = en_tokenizer("Some text about Colombia and the Czech Republic")
|
||||
assert weakref.ref(doc1) == doc1.spans.doc_ref
|
||||
doc2 = doc1.copy()
|
||||
assert weakref.ref(doc2) == doc2.spans.doc_ref
|
||||
|
|
|
@ -33,8 +33,10 @@ class SpanGroups(UserDict):
|
|||
def _make_span_group(self, name: str, spans: Iterable["Span"]) -> SpanGroup:
|
||||
return SpanGroup(self.doc_ref(), name=name, spans=spans)
|
||||
|
||||
def copy(self) -> "SpanGroups":
|
||||
return SpanGroups(self.doc_ref()).from_bytes(self.to_bytes())
|
||||
def copy(self, doc: "Doc" = None) -> "SpanGroups":
|
||||
if doc is None:
|
||||
doc = self.doc_ref()
|
||||
return SpanGroups(doc).from_bytes(self.to_bytes())
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
# We don't need to serialize this as a dict, because the groups
|
||||
|
|
|
@ -1188,7 +1188,7 @@ cdef class Doc:
|
|||
other.user_span_hooks = dict(self.user_span_hooks)
|
||||
other.length = self.length
|
||||
other.max_length = self.max_length
|
||||
other.spans = self.spans.copy()
|
||||
other.spans = self.spans.copy(doc=other)
|
||||
buff_size = other.max_length + (PADDING*2)
|
||||
assert buff_size > 0
|
||||
tokens = <TokenC*>other.mem.alloc(buff_size, sizeof(TokenC))
|
||||
|
|
Loading…
Reference in New Issue
Block a user