mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 00:04:15 +03:00
Span/SpanGroup: wrap SpanC in shared_ptr (#9869)
* Span/SpanGroup: wrap SpanC in shared_ptr When a Span that was retrieved from a SpanGroup was modified, these changes were not reflected in the SpanGroup because the underlying SpanC struct was copied. This change applies the solution proposed by @nrodnova, to wrap SpanC in a shared_ptr. This makes a SpanGroup and Spans derived from it share the same SpanC. So, changes made through a Span are visible in the SpanGroup as well. Fixes #9556 * Test that a SpanGroup is modified through its Spans * SpanGroup.push_back: remove nogil Modifying std::vector is not thread-safe. * C++ >= 11 does not allow const T in vector<T> * Add Span.span_c as a shorthand for Span.c.get Since this method is cdef'ed, it is only visible from Cython, so we avoid using raw pointers in Python Replace existing uses of span.c.get() to use this new method. * Fix formatting * Style fix: pointer types * SpanGroup.to_bytes: reduce number of shared_ptr::get calls * Mark SpanGroup modification test with issue Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
d8a3012539
commit
75f7c15187
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import random
|
||||
from libc.stdint cimport int32_t
|
||||
from libcpp.memory cimport shared_ptr
|
||||
from libcpp.vector cimport vector
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from collections import Counter
|
||||
|
@ -42,9 +44,7 @@ MOVE_NAMES[OUT] = 'O'
|
|||
|
||||
cdef struct GoldNERStateC:
|
||||
Transition* ner
|
||||
SpanC* negs
|
||||
int32_t length
|
||||
int32_t nr_neg
|
||||
vector[shared_ptr[SpanC]] negs
|
||||
|
||||
|
||||
cdef class BiluoGold:
|
||||
|
@ -77,8 +77,6 @@ cdef GoldNERStateC create_gold_state(
|
|||
negs = []
|
||||
assert example.x.length > 0
|
||||
gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition))
|
||||
gs.negs = <SpanC*>mem.alloc(len(negs), sizeof(SpanC))
|
||||
gs.nr_neg = len(negs)
|
||||
ner_ents, ner_tags = example.get_aligned_ents_and_ner()
|
||||
for i, ner_tag in enumerate(ner_tags):
|
||||
gs.ner[i] = moves.lookup_transition(ner_tag)
|
||||
|
@ -92,8 +90,8 @@ cdef GoldNERStateC create_gold_state(
|
|||
# In order to handle negative samples, we need to maintain the full
|
||||
# (start, end, label) triple. If we break it down to the 'isnt B-LOC'
|
||||
# thing, we'll get blocked if there's an incorrect prefix.
|
||||
for i, neg in enumerate(negs):
|
||||
gs.negs[i] = neg.c
|
||||
for neg in negs:
|
||||
gs.negs.push_back(neg.c)
|
||||
return gs
|
||||
|
||||
|
||||
|
@ -410,6 +408,8 @@ cdef class Begin:
|
|||
cdef int g_act = gold.ner[b0].move
|
||||
cdef attr_t g_tag = gold.ner[b0].label
|
||||
|
||||
cdef shared_ptr[SpanC] span
|
||||
|
||||
if g_act == MISSING:
|
||||
pass
|
||||
elif g_act == BEGIN:
|
||||
|
@ -427,8 +427,8 @@ cdef class Begin:
|
|||
# be correct or not. However, we can at least tell whether we're
|
||||
# going to be opening an entity where there's only one possible
|
||||
# L.
|
||||
for span in gold.negs[:gold.nr_neg]:
|
||||
if span.label == label and span.start == b0:
|
||||
for span in gold.negs:
|
||||
if span.get().label == label and span.get().start == b0:
|
||||
cost += 1
|
||||
break
|
||||
return cost
|
||||
|
@ -573,8 +573,9 @@ cdef class Last:
|
|||
# If we have negative-example entities, integrate them into the objective,
|
||||
# by marking actions that close an entity that we know is incorrect
|
||||
# as costly.
|
||||
for span in gold.negs[:gold.nr_neg]:
|
||||
if span.label == label and (span.end-1) == b0 and span.start == ent_start:
|
||||
cdef shared_ptr[SpanC] span
|
||||
for span in gold.negs:
|
||||
if span.get().label == label and (span.get().end-1) == b0 and span.get().start == ent_start:
|
||||
cost += 1
|
||||
break
|
||||
return cost
|
||||
|
@ -638,8 +639,9 @@ cdef class Unit:
|
|||
# This is fairly straight-forward for U- entities, as we have a single
|
||||
# action
|
||||
cdef int b0 = s.B(0)
|
||||
for span in gold.negs[:gold.nr_neg]:
|
||||
if span.label == label and span.start == b0 and span.end == (b0+1):
|
||||
cdef shared_ptr[SpanC] span
|
||||
for span in gold.negs:
|
||||
if span.get().label == label and span.get().start == b0 and span.get().end == (b0+1):
|
||||
cost += 1
|
||||
break
|
||||
return cost
|
||||
|
|
|
@ -4,7 +4,7 @@ from numpy.testing import assert_array_equal
|
|||
|
||||
from spacy.attrs import ORTH, LENGTH
|
||||
from spacy.lang.en import English
|
||||
from spacy.tokens import Doc, Span, Token
|
||||
from spacy.tokens import Doc, Span, SpanGroup, Token
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.util import filter_spans
|
||||
from thinc.api import get_current_ops
|
||||
|
@ -163,6 +163,18 @@ def test_char_span(doc, i_sent, i, j, text):
|
|||
assert span.text == text
|
||||
|
||||
|
||||
@pytest.mark.issue(9556)
|
||||
def test_modify_span_group(doc):
|
||||
group = SpanGroup(doc, spans=doc.ents)
|
||||
for span in group:
|
||||
span.start = 0
|
||||
span.label = doc.vocab.strings["TEST"]
|
||||
|
||||
# Span changes must be reflected in the span group
|
||||
assert group[0].start == 0
|
||||
assert group[0].label == doc.vocab.strings["TEST"]
|
||||
|
||||
|
||||
def test_spans_sent_spans(doc):
|
||||
sents = list(doc.sents)
|
||||
assert sents[0].start == 0
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from libcpp.memory cimport shared_ptr
|
||||
cimport numpy as np
|
||||
|
||||
from .doc cimport Doc
|
||||
|
@ -7,19 +8,21 @@ from ..structs cimport SpanC
|
|||
|
||||
cdef class Span:
|
||||
cdef readonly Doc doc
|
||||
cdef SpanC c
|
||||
cdef shared_ptr[SpanC] c
|
||||
cdef public _vector
|
||||
cdef public _vector_norm
|
||||
|
||||
@staticmethod
|
||||
cdef inline Span cinit(Doc doc, SpanC span):
|
||||
cdef inline Span cinit(Doc doc, const shared_ptr[SpanC] &span):
|
||||
cdef Span self = Span.__new__(
|
||||
Span,
|
||||
doc,
|
||||
start=span.start,
|
||||
end=span.end
|
||||
start=span.get().start,
|
||||
end=span.get().end
|
||||
)
|
||||
self.c = span
|
||||
return self
|
||||
|
||||
cpdef np.ndarray to_array(self, object features)
|
||||
|
||||
cdef SpanC* span_c(self)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
cimport numpy as np
|
||||
from libc.math cimport sqrt
|
||||
from libcpp.memory cimport make_shared
|
||||
|
||||
import numpy
|
||||
from thinc.api import get_array_module
|
||||
|
@ -109,14 +110,14 @@ cdef class Span:
|
|||
end_char = start_char
|
||||
else:
|
||||
end_char = doc[end - 1].idx + len(doc[end - 1])
|
||||
self.c = SpanC(
|
||||
self.c = make_shared[SpanC](SpanC(
|
||||
label=label,
|
||||
kb_id=kb_id,
|
||||
start=start,
|
||||
end=end,
|
||||
start_char=start_char,
|
||||
end_char=end_char,
|
||||
)
|
||||
))
|
||||
self._vector = vector
|
||||
self._vector_norm = vector_norm
|
||||
|
||||
|
@ -126,41 +127,46 @@ cdef class Span:
|
|||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
cdef SpanC* other_span_c = other.span_c()
|
||||
|
||||
# <
|
||||
if op == 0:
|
||||
return self.c.start_char < other.c.start_char
|
||||
return span_c.start_char < other_span_c.start_char
|
||||
# <=
|
||||
elif op == 1:
|
||||
return self.c.start_char <= other.c.start_char
|
||||
return span_c.start_char <= other_span_c.start_char
|
||||
# ==
|
||||
elif op == 2:
|
||||
# Do the cheap comparisons first
|
||||
return (
|
||||
(self.c.start_char == other.c.start_char) and \
|
||||
(self.c.end_char == other.c.end_char) and \
|
||||
(self.c.label == other.c.label) and \
|
||||
(self.c.kb_id == other.c.kb_id) and \
|
||||
(span_c.start_char == other_span_c.start_char) and \
|
||||
(span_c.end_char == other_span_c.end_char) and \
|
||||
(span_c.label == other_span_c.label) and \
|
||||
(span_c.kb_id == other_span_c.kb_id) and \
|
||||
(self.doc == other.doc)
|
||||
)
|
||||
# !=
|
||||
elif op == 3:
|
||||
# Do the cheap comparisons first
|
||||
return not (
|
||||
(self.c.start_char == other.c.start_char) and \
|
||||
(self.c.end_char == other.c.end_char) and \
|
||||
(self.c.label == other.c.label) and \
|
||||
(self.c.kb_id == other.c.kb_id) and \
|
||||
(span_c.start_char == other_span_c.start_char) and \
|
||||
(span_c.end_char == other_span_c.end_char) and \
|
||||
(span_c.label == other_span_c.label) and \
|
||||
(span_c.kb_id == other_span_c.kb_id) and \
|
||||
(self.doc == other.doc)
|
||||
)
|
||||
# >
|
||||
elif op == 4:
|
||||
return self.c.start_char > other.c.start_char
|
||||
return span_c.start_char > other_span_c.start_char
|
||||
# >=
|
||||
elif op == 5:
|
||||
return self.c.start_char >= other.c.start_char
|
||||
return span_c.start_char >= other_span_c.start_char
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id))
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
return hash((self.doc, span_c.start_char, span_c.end_char, span_c.label, span_c.kb_id))
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of tokens in the span.
|
||||
|
@ -169,9 +175,10 @@ cdef class Span:
|
|||
|
||||
DOCS: https://spacy.io/api/span#len
|
||||
"""
|
||||
if self.c.end < self.c.start:
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
if span_c.end < span_c.start:
|
||||
return 0
|
||||
return self.c.end - self.c.start
|
||||
return span_c.end - span_c.start
|
||||
|
||||
def __repr__(self):
|
||||
return self.text
|
||||
|
@ -185,15 +192,16 @@ cdef class Span:
|
|||
|
||||
DOCS: https://spacy.io/api/span#getitem
|
||||
"""
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
if isinstance(i, slice):
|
||||
start, end = normalize_slice(len(self), i.start, i.stop, i.step)
|
||||
return Span(self.doc, start + self.start, end + self.start)
|
||||
else:
|
||||
if i < 0:
|
||||
token_i = self.c.end + i
|
||||
token_i = span_c.end + i
|
||||
else:
|
||||
token_i = self.c.start + i
|
||||
if self.c.start <= token_i < self.c.end:
|
||||
token_i = span_c.start + i
|
||||
if span_c.start <= token_i < span_c.end:
|
||||
return self.doc[token_i]
|
||||
else:
|
||||
raise IndexError(Errors.E1002)
|
||||
|
@ -205,7 +213,8 @@ cdef class Span:
|
|||
|
||||
DOCS: https://spacy.io/api/span#iter
|
||||
"""
|
||||
for i in range(self.c.start, self.c.end):
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
for i in range(span_c.start, span_c.end):
|
||||
yield self.doc[i]
|
||||
|
||||
def __reduce__(self):
|
||||
|
@ -213,9 +222,10 @@ cdef class Span:
|
|||
|
||||
@property
|
||||
def _(self):
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
"""Custom extension attributes registered via `set_extension`."""
|
||||
return Underscore(Underscore.span_extensions, self,
|
||||
start=self.c.start_char, end=self.c.end_char)
|
||||
start=span_c.start_char, end=span_c.end_char)
|
||||
|
||||
def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None):
|
||||
"""Create a `Doc` object with a copy of the `Span`'s data.
|
||||
|
@ -289,13 +299,14 @@ cdef class Span:
|
|||
cdef int length = len(array)
|
||||
cdef attr_t value
|
||||
cdef int i, head_col, ancestor_i
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
old_to_new_root = dict()
|
||||
if HEAD in attrs:
|
||||
head_col = attrs.index(HEAD)
|
||||
for i in range(length):
|
||||
# if the HEAD refers to a token outside this span, find a more appropriate ancestor
|
||||
token = self[i]
|
||||
ancestor_i = token.head.i - self.c.start # span offset
|
||||
ancestor_i = token.head.i - span_c.start # span offset
|
||||
if ancestor_i not in range(length):
|
||||
if DEP in attrs:
|
||||
array[i, attrs.index(DEP)] = dep
|
||||
|
@ -303,7 +314,7 @@ cdef class Span:
|
|||
# try finding an ancestor within this span
|
||||
ancestors = token.ancestors
|
||||
for ancestor in ancestors:
|
||||
ancestor_i = ancestor.i - self.c.start
|
||||
ancestor_i = ancestor.i - span_c.start
|
||||
if ancestor_i in range(length):
|
||||
array[i, head_col] = ancestor_i - i
|
||||
|
||||
|
@ -332,7 +343,8 @@ cdef class Span:
|
|||
|
||||
DOCS: https://spacy.io/api/span#get_lca_matrix
|
||||
"""
|
||||
return numpy.asarray(_get_lca_matrix(self.doc, self.c.start, self.c.end))
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
return numpy.asarray(_get_lca_matrix(self.doc, span_c.start, span_c.end))
|
||||
|
||||
def similarity(self, other):
|
||||
"""Make a semantic similarity estimate. The default estimate is cosine
|
||||
|
@ -426,6 +438,9 @@ cdef class Span:
|
|||
else:
|
||||
raise ValueError(Errors.E030)
|
||||
|
||||
cdef SpanC* span_c(self):
|
||||
return self.c.get()
|
||||
|
||||
@property
|
||||
def sents(self):
|
||||
"""Obtain the sentences that contain this span. If the given span
|
||||
|
@ -477,10 +492,13 @@ cdef class Span:
|
|||
DOCS: https://spacy.io/api/span#ents
|
||||
"""
|
||||
cdef Span ent
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
cdef SpanC* ent_span_c
|
||||
ents = []
|
||||
for ent in self.doc.ents:
|
||||
if ent.c.start >= self.c.start:
|
||||
if ent.c.end <= self.c.end:
|
||||
ent_span_c = ent.span_c()
|
||||
if ent_span_c.start >= span_c.start:
|
||||
if ent_span_c.end <= span_c.end:
|
||||
ents.append(ent)
|
||||
else:
|
||||
break
|
||||
|
@ -615,11 +633,12 @@ cdef class Span:
|
|||
# This should probably be called 'head', and the other one called
|
||||
# 'gov'. But we went with 'head' elsewhere, and now we're stuck =/
|
||||
cdef int i
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
# First, we scan through the Span, and check whether there's a word
|
||||
# with head==0, i.e. a sentence root. If so, we can return it. The
|
||||
# longer the span, the more likely it contains a sentence root, and
|
||||
# in this case we return in linear time.
|
||||
for i in range(self.c.start, self.c.end):
|
||||
for i in range(span_c.start, span_c.end):
|
||||
if self.doc.c[i].head == 0:
|
||||
return self.doc[i]
|
||||
# If we don't have a sentence root, we do something that's not so
|
||||
|
@ -630,15 +649,15 @@ cdef class Span:
|
|||
# think this should be okay.
|
||||
cdef int current_best = self.doc.length
|
||||
cdef int root = -1
|
||||
for i in range(self.c.start, self.c.end):
|
||||
if self.c.start <= (i+self.doc.c[i].head) < self.c.end:
|
||||
for i in range(span_c.start, span_c.end):
|
||||
if span_c.start <= (i+self.doc.c[i].head) < span_c.end:
|
||||
continue
|
||||
words_to_root = _count_words_to_root(&self.doc.c[i], self.doc.length)
|
||||
if words_to_root < current_best:
|
||||
current_best = words_to_root
|
||||
root = i
|
||||
if root == -1:
|
||||
return self.doc[self.c.start]
|
||||
return self.doc[span_c.start]
|
||||
else:
|
||||
return self.doc[root]
|
||||
|
||||
|
@ -654,8 +673,9 @@ cdef class Span:
|
|||
the span.
|
||||
RETURNS (Span): The newly constructed object.
|
||||
"""
|
||||
start_idx += self.c.start_char
|
||||
end_idx += self.c.start_char
|
||||
cdef SpanC* span_c = self.span_c()
|
||||
start_idx += span_c.start_char
|
||||
end_idx += span_c.start_char
|
||||
return self.doc.char_span(start_idx, end_idx, label=label, kb_id=kb_id, vector=vector)
|
||||
|
||||
@property
|
||||
|
@ -736,53 +756,53 @@ cdef class Span:
|
|||
|
||||
property start:
|
||||
def __get__(self):
|
||||
return self.c.start
|
||||
return self.span_c().start
|
||||
|
||||
def __set__(self, int start):
|
||||
if start < 0:
|
||||
raise IndexError("TODO")
|
||||
self.c.start = start
|
||||
self.span_c().start = start
|
||||
|
||||
property end:
|
||||
def __get__(self):
|
||||
return self.c.end
|
||||
return self.span_c().end
|
||||
|
||||
def __set__(self, int end):
|
||||
if end < 0:
|
||||
raise IndexError("TODO")
|
||||
self.c.end = end
|
||||
self.span_c().end = end
|
||||
|
||||
property start_char:
|
||||
def __get__(self):
|
||||
return self.c.start_char
|
||||
return self.span_c().start_char
|
||||
|
||||
def __set__(self, int start_char):
|
||||
if start_char < 0:
|
||||
raise IndexError("TODO")
|
||||
self.c.start_char = start_char
|
||||
self.span_c().start_char = start_char
|
||||
|
||||
property end_char:
|
||||
def __get__(self):
|
||||
return self.c.end_char
|
||||
return self.span_c().end_char
|
||||
|
||||
def __set__(self, int end_char):
|
||||
if end_char < 0:
|
||||
raise IndexError("TODO")
|
||||
self.c.end_char = end_char
|
||||
self.span_c().end_char = end_char
|
||||
|
||||
property label:
|
||||
def __get__(self):
|
||||
return self.c.label
|
||||
return self.span_c().label
|
||||
|
||||
def __set__(self, attr_t label):
|
||||
self.c.label = label
|
||||
self.span_c().label = label
|
||||
|
||||
property kb_id:
|
||||
def __get__(self):
|
||||
return self.c.kb_id
|
||||
return self.span_c().kb_id
|
||||
|
||||
def __set__(self, attr_t kb_id):
|
||||
self.c.kb_id = kb_id
|
||||
self.span_c().kb_id = kb_id
|
||||
|
||||
property ent_id:
|
||||
"""RETURNS (uint64): The entity ID."""
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from libcpp.memory cimport shared_ptr
|
||||
from libcpp.vector cimport vector
|
||||
from ..structs cimport SpanC
|
||||
|
||||
|
@ -5,6 +6,6 @@ cdef class SpanGroup:
|
|||
cdef public object _doc_ref
|
||||
cdef public str name
|
||||
cdef public dict attrs
|
||||
cdef vector[SpanC] c
|
||||
cdef vector[shared_ptr[SpanC]] c
|
||||
|
||||
cdef void push_back(self, SpanC span) nogil
|
||||
cdef void push_back(self, const shared_ptr[SpanC] &span)
|
||||
|
|
|
@ -5,6 +5,7 @@ import srsly
|
|||
from spacy.errors import Errors
|
||||
from .span cimport Span
|
||||
from libc.stdint cimport uint64_t, uint32_t, int32_t
|
||||
from libcpp.memory cimport make_shared
|
||||
|
||||
|
||||
cdef class SpanGroup:
|
||||
|
@ -135,9 +136,11 @@ cdef class SpanGroup:
|
|||
|
||||
DOCS: https://spacy.io/api/spangroup#to_bytes
|
||||
"""
|
||||
cdef SpanC* span_c
|
||||
output = {"name": self.name, "attrs": self.attrs, "spans": []}
|
||||
for i in range(self.c.size()):
|
||||
span = self.c[i]
|
||||
span_c = span.get()
|
||||
# The struct.pack here is probably overkill, but it might help if
|
||||
# you're saving tonnes of spans, and it doesn't really add any
|
||||
# complexity. We do take care to specify little-endian byte order
|
||||
|
@ -149,13 +152,13 @@ cdef class SpanGroup:
|
|||
# l: int32_t
|
||||
output["spans"].append(struct.pack(
|
||||
">QQQllll",
|
||||
span.id,
|
||||
span.kb_id,
|
||||
span.label,
|
||||
span.start,
|
||||
span.end,
|
||||
span.start_char,
|
||||
span.end_char
|
||||
span_c.id,
|
||||
span_c.kb_id,
|
||||
span_c.label,
|
||||
span_c.start,
|
||||
span_c.end,
|
||||
span_c.start_char,
|
||||
span_c.end_char
|
||||
))
|
||||
return srsly.msgpack_dumps(output)
|
||||
|
||||
|
@ -182,8 +185,8 @@ cdef class SpanGroup:
|
|||
span.end = items[4]
|
||||
span.start_char = items[5]
|
||||
span.end_char = items[6]
|
||||
self.c.push_back(span)
|
||||
self.c.push_back(make_shared[SpanC](span))
|
||||
return self
|
||||
|
||||
cdef void push_back(self, SpanC span) nogil:
|
||||
cdef void push_back(self, const shared_ptr[SpanC] &span):
|
||||
self.c.push_back(span)
|
||||
|
|
Loading…
Reference in New Issue
Block a user