Span/SpanGroup: wrap SpanC in shared_ptr (#9869)

* Span/SpanGroup: wrap SpanC in shared_ptr

When a Span that was retrieved from a SpanGroup was modified, these
changes were not reflected in the SpanGroup because the underlying
SpanC struct was copied.

This change applies the solution proposed by @nrodnova, to wrap SpanC in
a shared_ptr. This makes a SpanGroup and Spans derived from it share the
same SpanC. So, changes made through a Span are visible in the SpanGroup
as well.

Fixes #9556

* Test that a SpanGroup is modified through its Spans

* SpanGroup.push_back: remove nogil

Modifying std::vector is not thread-safe.

* C++ >= 11 does not allow const T in vector<T>

* Add Span.span_c as a shorthand for Span.c.get

Since this method is cdef'ed, it is only visible from Cython, so we
avoid using raw pointers in Python

Replace existing uses of span.c.get() to use this new method.

* Fix formatting

* Style fix: pointer types

* SpanGroup.to_bytes: reduce number of shared_ptr::get calls

* Mark SpanGroup modification test with issue

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Daniël de Kok 2022-01-12 13:38:52 +01:00 committed by GitHub
parent d8a3012539
commit 75f7c15187
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 115 additions and 74 deletions

View File

@ -1,6 +1,8 @@
import os import os
import random import random
from libc.stdint cimport int32_t from libc.stdint cimport int32_t
from libcpp.memory cimport shared_ptr
from libcpp.vector cimport vector
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from collections import Counter from collections import Counter
@ -42,9 +44,7 @@ MOVE_NAMES[OUT] = 'O'
cdef struct GoldNERStateC: cdef struct GoldNERStateC:
Transition* ner Transition* ner
SpanC* negs vector[shared_ptr[SpanC]] negs
int32_t length
int32_t nr_neg
cdef class BiluoGold: cdef class BiluoGold:
@ -77,8 +77,6 @@ cdef GoldNERStateC create_gold_state(
negs = [] negs = []
assert example.x.length > 0 assert example.x.length > 0
gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition)) gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition))
gs.negs = <SpanC*>mem.alloc(len(negs), sizeof(SpanC))
gs.nr_neg = len(negs)
ner_ents, ner_tags = example.get_aligned_ents_and_ner() ner_ents, ner_tags = example.get_aligned_ents_and_ner()
for i, ner_tag in enumerate(ner_tags): for i, ner_tag in enumerate(ner_tags):
gs.ner[i] = moves.lookup_transition(ner_tag) gs.ner[i] = moves.lookup_transition(ner_tag)
@ -92,8 +90,8 @@ cdef GoldNERStateC create_gold_state(
# In order to handle negative samples, we need to maintain the full # In order to handle negative samples, we need to maintain the full
# (start, end, label) triple. If we break it down to the 'isnt B-LOC' # (start, end, label) triple. If we break it down to the 'isnt B-LOC'
# thing, we'll get blocked if there's an incorrect prefix. # thing, we'll get blocked if there's an incorrect prefix.
for i, neg in enumerate(negs): for neg in negs:
gs.negs[i] = neg.c gs.negs.push_back(neg.c)
return gs return gs
@ -410,6 +408,8 @@ cdef class Begin:
cdef int g_act = gold.ner[b0].move cdef int g_act = gold.ner[b0].move
cdef attr_t g_tag = gold.ner[b0].label cdef attr_t g_tag = gold.ner[b0].label
cdef shared_ptr[SpanC] span
if g_act == MISSING: if g_act == MISSING:
pass pass
elif g_act == BEGIN: elif g_act == BEGIN:
@ -427,8 +427,8 @@ cdef class Begin:
# be correct or not. However, we can at least tell whether we're # be correct or not. However, we can at least tell whether we're
# going to be opening an entity where there's only one possible # going to be opening an entity where there's only one possible
# L. # L.
for span in gold.negs[:gold.nr_neg]: for span in gold.negs:
if span.label == label and span.start == b0: if span.get().label == label and span.get().start == b0:
cost += 1 cost += 1
break break
return cost return cost
@ -573,8 +573,9 @@ cdef class Last:
# If we have negative-example entities, integrate them into the objective, # If we have negative-example entities, integrate them into the objective,
# by marking actions that close an entity that we know is incorrect # by marking actions that close an entity that we know is incorrect
# as costly. # as costly.
for span in gold.negs[:gold.nr_neg]: cdef shared_ptr[SpanC] span
if span.label == label and (span.end-1) == b0 and span.start == ent_start: for span in gold.negs:
if span.get().label == label and (span.get().end-1) == b0 and span.get().start == ent_start:
cost += 1 cost += 1
break break
return cost return cost
@ -638,8 +639,9 @@ cdef class Unit:
# This is fairly straight-forward for U- entities, as we have a single # This is fairly straight-forward for U- entities, as we have a single
# action # action
cdef int b0 = s.B(0) cdef int b0 = s.B(0)
for span in gold.negs[:gold.nr_neg]: cdef shared_ptr[SpanC] span
if span.label == label and span.start == b0 and span.end == (b0+1): for span in gold.negs:
if span.get().label == label and span.get().start == b0 and span.get().end == (b0+1):
cost += 1 cost += 1
break break
return cost return cost

View File

@ -4,7 +4,7 @@ from numpy.testing import assert_array_equal
from spacy.attrs import ORTH, LENGTH from spacy.attrs import ORTH, LENGTH
from spacy.lang.en import English from spacy.lang.en import English
from spacy.tokens import Doc, Span, Token from spacy.tokens import Doc, Span, SpanGroup, Token
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.util import filter_spans from spacy.util import filter_spans
from thinc.api import get_current_ops from thinc.api import get_current_ops
@ -163,6 +163,18 @@ def test_char_span(doc, i_sent, i, j, text):
assert span.text == text assert span.text == text
@pytest.mark.issue(9556)
def test_modify_span_group(doc):
group = SpanGroup(doc, spans=doc.ents)
for span in group:
span.start = 0
span.label = doc.vocab.strings["TEST"]
# Span changes must be reflected in the span group
assert group[0].start == 0
assert group[0].label == doc.vocab.strings["TEST"]
def test_spans_sent_spans(doc): def test_spans_sent_spans(doc):
sents = list(doc.sents) sents = list(doc.sents)
assert sents[0].start == 0 assert sents[0].start == 0

View File

@ -1,3 +1,4 @@
from libcpp.memory cimport shared_ptr
cimport numpy as np cimport numpy as np
from .doc cimport Doc from .doc cimport Doc
@ -7,19 +8,21 @@ from ..structs cimport SpanC
cdef class Span: cdef class Span:
cdef readonly Doc doc cdef readonly Doc doc
cdef SpanC c cdef shared_ptr[SpanC] c
cdef public _vector cdef public _vector
cdef public _vector_norm cdef public _vector_norm
@staticmethod @staticmethod
cdef inline Span cinit(Doc doc, SpanC span): cdef inline Span cinit(Doc doc, const shared_ptr[SpanC] &span):
cdef Span self = Span.__new__( cdef Span self = Span.__new__(
Span, Span,
doc, doc,
start=span.start, start=span.get().start,
end=span.end end=span.get().end
) )
self.c = span self.c = span
return self return self
cpdef np.ndarray to_array(self, object features) cpdef np.ndarray to_array(self, object features)
cdef SpanC* span_c(self)

View File

@ -1,5 +1,6 @@
cimport numpy as np cimport numpy as np
from libc.math cimport sqrt from libc.math cimport sqrt
from libcpp.memory cimport make_shared
import numpy import numpy
from thinc.api import get_array_module from thinc.api import get_array_module
@ -109,14 +110,14 @@ cdef class Span:
end_char = start_char end_char = start_char
else: else:
end_char = doc[end - 1].idx + len(doc[end - 1]) end_char = doc[end - 1].idx + len(doc[end - 1])
self.c = SpanC( self.c = make_shared[SpanC](SpanC(
label=label, label=label,
kb_id=kb_id, kb_id=kb_id,
start=start, start=start,
end=end, end=end,
start_char=start_char, start_char=start_char,
end_char=end_char, end_char=end_char,
) ))
self._vector = vector self._vector = vector
self._vector_norm = vector_norm self._vector_norm = vector_norm
@ -126,41 +127,46 @@ cdef class Span:
return False return False
else: else:
return True return True
cdef SpanC* span_c = self.span_c()
cdef SpanC* other_span_c = other.span_c()
# < # <
if op == 0: if op == 0:
return self.c.start_char < other.c.start_char return span_c.start_char < other_span_c.start_char
# <= # <=
elif op == 1: elif op == 1:
return self.c.start_char <= other.c.start_char return span_c.start_char <= other_span_c.start_char
# == # ==
elif op == 2: elif op == 2:
# Do the cheap comparisons first # Do the cheap comparisons first
return ( return (
(self.c.start_char == other.c.start_char) and \ (span_c.start_char == other_span_c.start_char) and \
(self.c.end_char == other.c.end_char) and \ (span_c.end_char == other_span_c.end_char) and \
(self.c.label == other.c.label) and \ (span_c.label == other_span_c.label) and \
(self.c.kb_id == other.c.kb_id) and \ (span_c.kb_id == other_span_c.kb_id) and \
(self.doc == other.doc) (self.doc == other.doc)
) )
# != # !=
elif op == 3: elif op == 3:
# Do the cheap comparisons first # Do the cheap comparisons first
return not ( return not (
(self.c.start_char == other.c.start_char) and \ (span_c.start_char == other_span_c.start_char) and \
(self.c.end_char == other.c.end_char) and \ (span_c.end_char == other_span_c.end_char) and \
(self.c.label == other.c.label) and \ (span_c.label == other_span_c.label) and \
(self.c.kb_id == other.c.kb_id) and \ (span_c.kb_id == other_span_c.kb_id) and \
(self.doc == other.doc) (self.doc == other.doc)
) )
# > # >
elif op == 4: elif op == 4:
return self.c.start_char > other.c.start_char return span_c.start_char > other_span_c.start_char
# >= # >=
elif op == 5: elif op == 5:
return self.c.start_char >= other.c.start_char return span_c.start_char >= other_span_c.start_char
def __hash__(self): def __hash__(self):
return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id)) cdef SpanC* span_c = self.span_c()
return hash((self.doc, span_c.start_char, span_c.end_char, span_c.label, span_c.kb_id))
def __len__(self): def __len__(self):
"""Get the number of tokens in the span. """Get the number of tokens in the span.
@ -169,9 +175,10 @@ cdef class Span:
DOCS: https://spacy.io/api/span#len DOCS: https://spacy.io/api/span#len
""" """
if self.c.end < self.c.start: cdef SpanC* span_c = self.span_c()
if span_c.end < span_c.start:
return 0 return 0
return self.c.end - self.c.start return span_c.end - span_c.start
def __repr__(self): def __repr__(self):
return self.text return self.text
@ -185,15 +192,16 @@ cdef class Span:
DOCS: https://spacy.io/api/span#getitem DOCS: https://spacy.io/api/span#getitem
""" """
cdef SpanC* span_c = self.span_c()
if isinstance(i, slice): if isinstance(i, slice):
start, end = normalize_slice(len(self), i.start, i.stop, i.step) start, end = normalize_slice(len(self), i.start, i.stop, i.step)
return Span(self.doc, start + self.start, end + self.start) return Span(self.doc, start + self.start, end + self.start)
else: else:
if i < 0: if i < 0:
token_i = self.c.end + i token_i = span_c.end + i
else: else:
token_i = self.c.start + i token_i = span_c.start + i
if self.c.start <= token_i < self.c.end: if span_c.start <= token_i < span_c.end:
return self.doc[token_i] return self.doc[token_i]
else: else:
raise IndexError(Errors.E1002) raise IndexError(Errors.E1002)
@ -205,7 +213,8 @@ cdef class Span:
DOCS: https://spacy.io/api/span#iter DOCS: https://spacy.io/api/span#iter
""" """
for i in range(self.c.start, self.c.end): cdef SpanC* span_c = self.span_c()
for i in range(span_c.start, span_c.end):
yield self.doc[i] yield self.doc[i]
def __reduce__(self): def __reduce__(self):
@ -213,9 +222,10 @@ cdef class Span:
@property @property
def _(self): def _(self):
cdef SpanC* span_c = self.span_c()
"""Custom extension attributes registered via `set_extension`.""" """Custom extension attributes registered via `set_extension`."""
return Underscore(Underscore.span_extensions, self, return Underscore(Underscore.span_extensions, self,
start=self.c.start_char, end=self.c.end_char) start=span_c.start_char, end=span_c.end_char)
def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None): def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None):
"""Create a `Doc` object with a copy of the `Span`'s data. """Create a `Doc` object with a copy of the `Span`'s data.
@ -289,13 +299,14 @@ cdef class Span:
cdef int length = len(array) cdef int length = len(array)
cdef attr_t value cdef attr_t value
cdef int i, head_col, ancestor_i cdef int i, head_col, ancestor_i
cdef SpanC* span_c = self.span_c()
old_to_new_root = dict() old_to_new_root = dict()
if HEAD in attrs: if HEAD in attrs:
head_col = attrs.index(HEAD) head_col = attrs.index(HEAD)
for i in range(length): for i in range(length):
# if the HEAD refers to a token outside this span, find a more appropriate ancestor # if the HEAD refers to a token outside this span, find a more appropriate ancestor
token = self[i] token = self[i]
ancestor_i = token.head.i - self.c.start # span offset ancestor_i = token.head.i - span_c.start # span offset
if ancestor_i not in range(length): if ancestor_i not in range(length):
if DEP in attrs: if DEP in attrs:
array[i, attrs.index(DEP)] = dep array[i, attrs.index(DEP)] = dep
@ -303,7 +314,7 @@ cdef class Span:
# try finding an ancestor within this span # try finding an ancestor within this span
ancestors = token.ancestors ancestors = token.ancestors
for ancestor in ancestors: for ancestor in ancestors:
ancestor_i = ancestor.i - self.c.start ancestor_i = ancestor.i - span_c.start
if ancestor_i in range(length): if ancestor_i in range(length):
array[i, head_col] = ancestor_i - i array[i, head_col] = ancestor_i - i
@ -332,7 +343,8 @@ cdef class Span:
DOCS: https://spacy.io/api/span#get_lca_matrix DOCS: https://spacy.io/api/span#get_lca_matrix
""" """
return numpy.asarray(_get_lca_matrix(self.doc, self.c.start, self.c.end)) cdef SpanC* span_c = self.span_c()
return numpy.asarray(_get_lca_matrix(self.doc, span_c.start, span_c.end))
def similarity(self, other): def similarity(self, other):
"""Make a semantic similarity estimate. The default estimate is cosine """Make a semantic similarity estimate. The default estimate is cosine
@ -426,6 +438,9 @@ cdef class Span:
else: else:
raise ValueError(Errors.E030) raise ValueError(Errors.E030)
cdef SpanC* span_c(self):
return self.c.get()
@property @property
def sents(self): def sents(self):
"""Obtain the sentences that contain this span. If the given span """Obtain the sentences that contain this span. If the given span
@ -477,10 +492,13 @@ cdef class Span:
DOCS: https://spacy.io/api/span#ents DOCS: https://spacy.io/api/span#ents
""" """
cdef Span ent cdef Span ent
cdef SpanC* span_c = self.span_c()
cdef SpanC* ent_span_c
ents = [] ents = []
for ent in self.doc.ents: for ent in self.doc.ents:
if ent.c.start >= self.c.start: ent_span_c = ent.span_c()
if ent.c.end <= self.c.end: if ent_span_c.start >= span_c.start:
if ent_span_c.end <= span_c.end:
ents.append(ent) ents.append(ent)
else: else:
break break
@ -615,11 +633,12 @@ cdef class Span:
# This should probably be called 'head', and the other one called # This should probably be called 'head', and the other one called
# 'gov'. But we went with 'head' elsewhere, and now we're stuck =/ # 'gov'. But we went with 'head' elsewhere, and now we're stuck =/
cdef int i cdef int i
cdef SpanC* span_c = self.span_c()
# First, we scan through the Span, and check whether there's a word # First, we scan through the Span, and check whether there's a word
# with head==0, i.e. a sentence root. If so, we can return it. The # with head==0, i.e. a sentence root. If so, we can return it. The
# longer the span, the more likely it contains a sentence root, and # longer the span, the more likely it contains a sentence root, and
# in this case we return in linear time. # in this case we return in linear time.
for i in range(self.c.start, self.c.end): for i in range(span_c.start, span_c.end):
if self.doc.c[i].head == 0: if self.doc.c[i].head == 0:
return self.doc[i] return self.doc[i]
# If we don't have a sentence root, we do something that's not so # If we don't have a sentence root, we do something that's not so
@ -630,15 +649,15 @@ cdef class Span:
# think this should be okay. # think this should be okay.
cdef int current_best = self.doc.length cdef int current_best = self.doc.length
cdef int root = -1 cdef int root = -1
for i in range(self.c.start, self.c.end): for i in range(span_c.start, span_c.end):
if self.c.start <= (i+self.doc.c[i].head) < self.c.end: if span_c.start <= (i+self.doc.c[i].head) < span_c.end:
continue continue
words_to_root = _count_words_to_root(&self.doc.c[i], self.doc.length) words_to_root = _count_words_to_root(&self.doc.c[i], self.doc.length)
if words_to_root < current_best: if words_to_root < current_best:
current_best = words_to_root current_best = words_to_root
root = i root = i
if root == -1: if root == -1:
return self.doc[self.c.start] return self.doc[span_c.start]
else: else:
return self.doc[root] return self.doc[root]
@ -654,8 +673,9 @@ cdef class Span:
the span. the span.
RETURNS (Span): The newly constructed object. RETURNS (Span): The newly constructed object.
""" """
start_idx += self.c.start_char cdef SpanC* span_c = self.span_c()
end_idx += self.c.start_char start_idx += span_c.start_char
end_idx += span_c.start_char
return self.doc.char_span(start_idx, end_idx, label=label, kb_id=kb_id, vector=vector) return self.doc.char_span(start_idx, end_idx, label=label, kb_id=kb_id, vector=vector)
@property @property
@ -736,53 +756,53 @@ cdef class Span:
property start: property start:
def __get__(self): def __get__(self):
return self.c.start return self.span_c().start
def __set__(self, int start): def __set__(self, int start):
if start < 0: if start < 0:
raise IndexError("TODO") raise IndexError("TODO")
self.c.start = start self.span_c().start = start
property end: property end:
def __get__(self): def __get__(self):
return self.c.end return self.span_c().end
def __set__(self, int end): def __set__(self, int end):
if end < 0: if end < 0:
raise IndexError("TODO") raise IndexError("TODO")
self.c.end = end self.span_c().end = end
property start_char: property start_char:
def __get__(self): def __get__(self):
return self.c.start_char return self.span_c().start_char
def __set__(self, int start_char): def __set__(self, int start_char):
if start_char < 0: if start_char < 0:
raise IndexError("TODO") raise IndexError("TODO")
self.c.start_char = start_char self.span_c().start_char = start_char
property end_char: property end_char:
def __get__(self): def __get__(self):
return self.c.end_char return self.span_c().end_char
def __set__(self, int end_char): def __set__(self, int end_char):
if end_char < 0: if end_char < 0:
raise IndexError("TODO") raise IndexError("TODO")
self.c.end_char = end_char self.span_c().end_char = end_char
property label: property label:
def __get__(self): def __get__(self):
return self.c.label return self.span_c().label
def __set__(self, attr_t label): def __set__(self, attr_t label):
self.c.label = label self.span_c().label = label
property kb_id: property kb_id:
def __get__(self): def __get__(self):
return self.c.kb_id return self.span_c().kb_id
def __set__(self, attr_t kb_id): def __set__(self, attr_t kb_id):
self.c.kb_id = kb_id self.span_c().kb_id = kb_id
property ent_id: property ent_id:
"""RETURNS (uint64): The entity ID.""" """RETURNS (uint64): The entity ID."""

View File

@ -1,3 +1,4 @@
from libcpp.memory cimport shared_ptr
from libcpp.vector cimport vector from libcpp.vector cimport vector
from ..structs cimport SpanC from ..structs cimport SpanC
@ -5,6 +6,6 @@ cdef class SpanGroup:
cdef public object _doc_ref cdef public object _doc_ref
cdef public str name cdef public str name
cdef public dict attrs cdef public dict attrs
cdef vector[SpanC] c cdef vector[shared_ptr[SpanC]] c
cdef void push_back(self, SpanC span) nogil cdef void push_back(self, const shared_ptr[SpanC] &span)

View File

@ -5,6 +5,7 @@ import srsly
from spacy.errors import Errors from spacy.errors import Errors
from .span cimport Span from .span cimport Span
from libc.stdint cimport uint64_t, uint32_t, int32_t from libc.stdint cimport uint64_t, uint32_t, int32_t
from libcpp.memory cimport make_shared
cdef class SpanGroup: cdef class SpanGroup:
@ -135,9 +136,11 @@ cdef class SpanGroup:
DOCS: https://spacy.io/api/spangroup#to_bytes DOCS: https://spacy.io/api/spangroup#to_bytes
""" """
cdef SpanC* span_c
output = {"name": self.name, "attrs": self.attrs, "spans": []} output = {"name": self.name, "attrs": self.attrs, "spans": []}
for i in range(self.c.size()): for i in range(self.c.size()):
span = self.c[i] span = self.c[i]
span_c = span.get()
# The struct.pack here is probably overkill, but it might help if # The struct.pack here is probably overkill, but it might help if
# you're saving tonnes of spans, and it doesn't really add any # you're saving tonnes of spans, and it doesn't really add any
# complexity. We do take care to specify little-endian byte order # complexity. We do take care to specify little-endian byte order
@ -149,13 +152,13 @@ cdef class SpanGroup:
# l: int32_t # l: int32_t
output["spans"].append(struct.pack( output["spans"].append(struct.pack(
">QQQllll", ">QQQllll",
span.id, span_c.id,
span.kb_id, span_c.kb_id,
span.label, span_c.label,
span.start, span_c.start,
span.end, span_c.end,
span.start_char, span_c.start_char,
span.end_char span_c.end_char
)) ))
return srsly.msgpack_dumps(output) return srsly.msgpack_dumps(output)
@ -182,8 +185,8 @@ cdef class SpanGroup:
span.end = items[4] span.end = items[4]
span.start_char = items[5] span.start_char = items[5]
span.end_char = items[6] span.end_char = items[6]
self.c.push_back(span) self.c.push_back(make_shared[SpanC](span))
return self return self
cdef void push_back(self, SpanC span) nogil: cdef void push_back(self, const shared_ptr[SpanC] &span):
self.c.push_back(span) self.c.push_back(span)