diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index 981c5cdd2..28f6c16c4 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -256,6 +256,10 @@ cdef class Matcher: # non-overlapping ones this `match` can be either (start, end) or # (start, end, alignments) depending on `with_alignments=` option. for key, *match in matches: + # Adjust span matches to doc offsets + if isinstance(doclike, Span): + match[0] += doclike.start + match[1] += doclike.start span_filter = self._filter.get(key) if span_filter is not None: pairs = pairs_by_id.get(key, []) @@ -286,9 +290,6 @@ cdef class Matcher: if as_spans: final_results = [] for key, start, end, *_ in final_matches: - if isinstance(doclike, Span): - start += doclike.start - end += doclike.start final_results.append(Span(doc, start, end, label=key)) elif with_alignments: # convert alignments List[Dict[str, int]] --> List[int] diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx index fab872f00..cc196d85a 100644 --- a/spacy/pipeline/_parser_internals/ner.pyx +++ b/spacy/pipeline/_parser_internals/ner.pyx @@ -1,6 +1,8 @@ import os import random from libc.stdint cimport int32_t +from libcpp.memory cimport shared_ptr +from libcpp.vector cimport vector from cymem.cymem cimport Pool from collections import Counter @@ -43,9 +45,7 @@ MOVE_NAMES[OUT] = 'O' cdef struct GoldNERStateC: Transition* ner - SpanC* negs - int32_t length - int32_t nr_neg + vector[shared_ptr[SpanC]] negs cdef class BiluoGold: @@ -78,8 +78,6 @@ cdef GoldNERStateC create_gold_state( negs = [] assert example.x.length > 0 gs.ner = mem.alloc(example.x.length, sizeof(Transition)) - gs.negs = mem.alloc(len(negs), sizeof(SpanC)) - gs.nr_neg = len(negs) ner_ents, ner_tags = example.get_aligned_ents_and_ner() for i, ner_tag in enumerate(ner_tags): gs.ner[i] = moves.lookup_transition(ner_tag) @@ -93,8 +91,8 @@ cdef GoldNERStateC create_gold_state( # In order to handle negative samples, we need to maintain the full # (start, end, label) triple. If we break it down to the 'isnt B-LOC' # thing, we'll get blocked if there's an incorrect prefix. - for i, neg in enumerate(negs): - gs.negs[i] = neg.c + for neg in negs: + gs.negs.push_back(neg.c) return gs @@ -411,6 +409,8 @@ cdef class Begin: cdef int g_act = gold.ner[b0].move cdef attr_t g_tag = gold.ner[b0].label + cdef shared_ptr[SpanC] span + if g_act == MISSING: pass elif g_act == BEGIN: @@ -428,8 +428,8 @@ cdef class Begin: # be correct or not. However, we can at least tell whether we're # going to be opening an entity where there's only one possible # L. - for span in gold.negs[:gold.nr_neg]: - if span.label == label and span.start == b0: + for span in gold.negs: + if span.get().label == label and span.get().start == b0: cost += 1 break return cost @@ -574,8 +574,9 @@ cdef class Last: # If we have negative-example entities, integrate them into the objective, # by marking actions that close an entity that we know is incorrect # as costly. - for span in gold.negs[:gold.nr_neg]: - if span.label == label and (span.end-1) == b0 and span.start == ent_start: + cdef shared_ptr[SpanC] span + for span in gold.negs: + if span.get().label == label and (span.get().end-1) == b0 and span.get().start == ent_start: cost += 1 break return cost @@ -639,8 +640,9 @@ cdef class Unit: # This is fairly straight-forward for U- entities, as we have a single # action cdef int b0 = s.B(0) - for span in gold.negs[:gold.nr_neg]: - if span.label == label and span.start == b0 and span.end == (b0+1): + cdef shared_ptr[SpanC] span + for span in gold.negs: + if span.get().label == label and span.get().start == b0 and span.get().end == (b0+1): cost += 1 break return cost diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py index 3676b35af..c6303c52d 100644 --- a/spacy/tests/doc/test_span.py +++ b/spacy/tests/doc/test_span.py @@ -4,7 +4,7 @@ from numpy.testing import assert_array_equal from spacy.attrs import ORTH, LENGTH from spacy.lang.en import English -from spacy.tokens import Doc, Span, Token +from spacy.tokens import Doc, Span, SpanGroup, Token from spacy.vocab import Vocab from spacy.util import filter_spans from thinc.api import get_current_ops @@ -163,6 +163,18 @@ def test_char_span(doc, i_sent, i, j, text): assert span.text == text +@pytest.mark.issue(9556) +def test_modify_span_group(doc): + group = SpanGroup(doc, spans=doc.ents) + for span in group: + span.start = 0 + span.label = doc.vocab.strings["TEST"] + + # Span changes must be reflected in the span group + assert group[0].start == 0 + assert group[0].label == doc.vocab.strings["TEST"] + + def test_spans_sent_spans(doc): sents = list(doc.sents) assert sents[0].start == 0 diff --git a/spacy/tests/doc/test_span_group.py b/spacy/tests/doc/test_span_group.py index 8c70a83e1..da3c24908 100644 --- a/spacy/tests/doc/test_span_group.py +++ b/spacy/tests/doc/test_span_group.py @@ -99,8 +99,10 @@ def test_span_group_set_item(doc, other_doc): span.label_ = "NEW LABEL" span.kb_id = doc.vocab.strings["KB_ID"] - assert span_group[index].label != span.label - assert span_group[index].kb_id != span.kb_id + # Indexing a span group returns a span in which C + # data is shared. + assert span_group[index].label == span.label + assert span_group[index].kb_id == span.kb_id span_group[index] = span assert span_group[index].start == span.start diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index e8c3d53e8..dc698ae30 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -602,9 +602,16 @@ def test_matcher_span(matcher): doc = Doc(matcher.vocab, words=text.split()) span_js = doc[:3] span_java = doc[4:] - assert len(matcher(doc)) == 2 - assert len(matcher(span_js)) == 1 - assert len(matcher(span_java)) == 1 + doc_matches = matcher(doc) + span_js_matches = matcher(span_js) + span_java_matches = matcher(span_java) + assert len(doc_matches) == 2 + assert len(span_js_matches) == 1 + assert len(span_java_matches) == 1 + + # match offsets always refer to the doc + assert doc_matches[0] == span_js_matches[0] + assert doc_matches[1] == span_java_matches[0] def test_matcher_as_spans(matcher): diff --git a/spacy/tokens/span.pxd b/spacy/tokens/span.pxd index 78bee0a8c..85553068e 100644 --- a/spacy/tokens/span.pxd +++ b/spacy/tokens/span.pxd @@ -1,3 +1,4 @@ +from libcpp.memory cimport shared_ptr cimport numpy as np from .doc cimport Doc @@ -7,19 +8,21 @@ from ..structs cimport SpanC cdef class Span: cdef readonly Doc doc - cdef SpanC c + cdef shared_ptr[SpanC] c cdef public _vector cdef public _vector_norm @staticmethod - cdef inline Span cinit(Doc doc, SpanC span): + cdef inline Span cinit(Doc doc, const shared_ptr[SpanC] &span): cdef Span self = Span.__new__( Span, doc, - start=span.start, - end=span.end + start=span.get().start, + end=span.get().end ) self.c = span return self cpdef np.ndarray to_array(self, object features) + + cdef SpanC* span_c(self) diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index ab888ae95..b66a7ada3 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -1,5 +1,6 @@ cimport numpy as np from libc.math cimport sqrt +from libcpp.memory cimport make_shared import numpy from thinc.api import get_array_module @@ -114,7 +115,7 @@ cdef class Span: end_char = start_char else: end_char = doc[end - 1].idx + len(doc[end - 1]) - self.c = SpanC( + self.c = make_shared[SpanC](SpanC( label=label, kb_id=kb_id, id=span_id, @@ -122,7 +123,7 @@ cdef class Span: end=end, start_char=start_char, end_char=end_char, - ) + )) self._vector = vector self._vector_norm = vector_norm @@ -132,8 +133,11 @@ cdef class Span: return False else: return True - self_tuple = (self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id, self.id, self.doc) - other_tuple = (other.c.start_char, other.c.end_char, other.c.label, other.c.kb_id, other.id, other.doc) + + cdef SpanC* span_c = self.span_c() + cdef SpanC* other_span_c = other.span_c() + self_tuple = (span_c.start_char, span_c.end_char, span_c.label, span_c.kb_id, self.id, self.doc) + other_tuple = (other_span_c.start_char, other_span_c.end_char, other_span_c.label, other_span_c.kb_id, other.id, other.doc) # < if op == 0: return self_tuple < other_tuple @@ -154,7 +158,8 @@ cdef class Span: return self_tuple >= other_tuple def __hash__(self): - return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id, self.c.id)) + cdef SpanC* span_c = self.span_c() + return hash((self.doc, span_c.start_char, span_c.end_char, span_c.label, span_c.kb_id, span_c.id)) def __len__(self): """Get the number of tokens in the span. @@ -163,9 +168,10 @@ cdef class Span: DOCS: https://spacy.io/api/span#len """ - if self.c.end < self.c.start: + cdef SpanC* span_c = self.span_c() + if span_c.end < span_c.start: return 0 - return self.c.end - self.c.start + return span_c.end - span_c.start def __repr__(self): return self.text @@ -179,15 +185,16 @@ cdef class Span: DOCS: https://spacy.io/api/span#getitem """ + cdef SpanC* span_c = self.span_c() if isinstance(i, slice): start, end = normalize_slice(len(self), i.start, i.stop, i.step) return Span(self.doc, start + self.start, end + self.start) else: if i < 0: - token_i = self.c.end + i + token_i = span_c.end + i else: - token_i = self.c.start + i - if self.c.start <= token_i < self.c.end: + token_i = span_c.start + i + if span_c.start <= token_i < span_c.end: return self.doc[token_i] else: raise IndexError(Errors.E1002) @@ -199,7 +206,8 @@ cdef class Span: DOCS: https://spacy.io/api/span#iter """ - for i in range(self.c.start, self.c.end): + cdef SpanC* span_c = self.span_c() + for i in range(span_c.start, span_c.end): yield self.doc[i] def __reduce__(self): @@ -207,9 +215,10 @@ cdef class Span: @property def _(self): + cdef SpanC* span_c = self.span_c() """Custom extension attributes registered via `set_extension`.""" return Underscore(Underscore.span_extensions, self, - start=self.c.start_char, end=self.c.end_char) + start=span_c.start_char, end=span_c.end_char) def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None): """Create a `Doc` object with a copy of the `Span`'s data. @@ -283,13 +292,14 @@ cdef class Span: cdef int length = len(array) cdef attr_t value cdef int i, head_col, ancestor_i + cdef SpanC* span_c = self.span_c() old_to_new_root = dict() if HEAD in attrs: head_col = attrs.index(HEAD) for i in range(length): # if the HEAD refers to a token outside this span, find a more appropriate ancestor token = self[i] - ancestor_i = token.head.i - self.c.start # span offset + ancestor_i = token.head.i - span_c.start # span offset if ancestor_i not in range(length): if DEP in attrs: array[i, attrs.index(DEP)] = dep @@ -297,7 +307,7 @@ cdef class Span: # try finding an ancestor within this span ancestors = token.ancestors for ancestor in ancestors: - ancestor_i = ancestor.i - self.c.start + ancestor_i = ancestor.i - span_c.start if ancestor_i in range(length): array[i, head_col] = ancestor_i - i @@ -326,7 +336,8 @@ cdef class Span: DOCS: https://spacy.io/api/span#get_lca_matrix """ - return numpy.asarray(_get_lca_matrix(self.doc, self.c.start, self.c.end)) + cdef SpanC* span_c = self.span_c() + return numpy.asarray(_get_lca_matrix(self.doc, span_c.start, span_c.end)) def similarity(self, other): """Make a semantic similarity estimate. The default estimate is cosine @@ -422,6 +433,9 @@ cdef class Span: else: raise ValueError(Errors.E030) + cdef SpanC* span_c(self): + return self.c.get() + @property def sents(self): """Obtain the sentences that contain this span. If the given span @@ -473,10 +487,13 @@ cdef class Span: DOCS: https://spacy.io/api/span#ents """ cdef Span ent + cdef SpanC* span_c = self.span_c() + cdef SpanC* ent_span_c ents = [] for ent in self.doc.ents: - if ent.c.start >= self.c.start: - if ent.c.end <= self.c.end: + ent_span_c = ent.span_c() + if ent_span_c.start >= span_c.start: + if ent_span_c.end <= span_c.end: ents.append(ent) else: break @@ -611,11 +628,12 @@ cdef class Span: # This should probably be called 'head', and the other one called # 'gov'. But we went with 'head' elsewhere, and now we're stuck =/ cdef int i + cdef SpanC* span_c = self.span_c() # First, we scan through the Span, and check whether there's a word # with head==0, i.e. a sentence root. If so, we can return it. The # longer the span, the more likely it contains a sentence root, and # in this case we return in linear time. - for i in range(self.c.start, self.c.end): + for i in range(span_c.start, span_c.end): if self.doc.c[i].head == 0: return self.doc[i] # If we don't have a sentence root, we do something that's not so @@ -626,15 +644,15 @@ cdef class Span: # think this should be okay. cdef int current_best = self.doc.length cdef int root = -1 - for i in range(self.c.start, self.c.end): - if self.c.start <= (i+self.doc.c[i].head) < self.c.end: + for i in range(span_c.start, span_c.end): + if span_c.start <= (i+self.doc.c[i].head) < span_c.end: continue words_to_root = _count_words_to_root(&self.doc.c[i], self.doc.length) if words_to_root < current_best: current_best = words_to_root root = i if root == -1: - return self.doc[self.c.start] + return self.doc[span_c.start] else: return self.doc[root] @@ -650,8 +668,9 @@ cdef class Span: the span. RETURNS (Span): The newly constructed object. """ - start_idx += self.c.start_char - end_idx += self.c.start_char + cdef SpanC* span_c = self.span_c() + start_idx += span_c.start_char + end_idx += span_c.start_char return self.doc.char_span(start_idx, end_idx, label=label, kb_id=kb_id, vector=vector) @property @@ -732,60 +751,62 @@ cdef class Span: property start: def __get__(self): - return self.c.start + return self.span_c().start def __set__(self, int start): if start < 0: raise IndexError(Errors.E1032.format(var="start", forbidden="< 0", value=start)) - self.c.start = start + self.span_c().start = start property end: def __get__(self): - return self.c.end + return self.span_c().end def __set__(self, int end): if end < 0: raise IndexError(Errors.E1032.format(var="end", forbidden="< 0", value=end)) - self.c.end = end + self.span_c().end = end property start_char: def __get__(self): - return self.c.start_char + return self.span_c().start_char def __set__(self, int start_char): if start_char < 0: raise IndexError(Errors.E1032.format(var="start_char", forbidden="< 0", value=start_char)) - self.c.start_char = start_char + self.span_c().start_char = start_char property end_char: def __get__(self): - return self.c.end_char + return self.span_c().end_char def __set__(self, int end_char): if end_char < 0: raise IndexError(Errors.E1032.format(var="end_char", forbidden="< 0", value=end_char)) - self.c.end_char = end_char + self.span_c().end_char = end_char property label: def __get__(self): - return self.c.label + return self.span_c().label def __set__(self, attr_t label): - self.c.label = label + self.span_c().label = label property kb_id: def __get__(self): - return self.c.kb_id + return self.span_c().kb_id def __set__(self, attr_t kb_id): - self.c.kb_id = kb_id + self.span_c().kb_id = kb_id property id: def __get__(self): - return self.c.id + cdef SpanC* span_c = self.span_c() + return span_c.id def __set__(self, attr_t id): - self.c.id = id + cdef SpanC* span_c = self.span_c() + span_c.id = id property ent_id: """RETURNS (uint64): The entity ID.""" diff --git a/spacy/tokens/span_group.pxd b/spacy/tokens/span_group.pxd index 5074aa275..6b817578a 100644 --- a/spacy/tokens/span_group.pxd +++ b/spacy/tokens/span_group.pxd @@ -1,3 +1,4 @@ +from libcpp.memory cimport shared_ptr from libcpp.vector cimport vector from ..structs cimport SpanC @@ -5,6 +6,6 @@ cdef class SpanGroup: cdef public object _doc_ref cdef public str name cdef public dict attrs - cdef vector[SpanC] c + cdef vector[shared_ptr[SpanC]] c - cdef void push_back(self, SpanC span) nogil + cdef void push_back(self, const shared_ptr[SpanC] &span) diff --git a/spacy/tokens/span_group.pyx b/spacy/tokens/span_group.pyx index bb0fab24f..7d4d7a248 100644 --- a/spacy/tokens/span_group.pyx +++ b/spacy/tokens/span_group.pyx @@ -6,6 +6,7 @@ import srsly from spacy.errors import Errors from .span cimport Span +from libcpp.memory cimport make_shared cdef class SpanGroup: @@ -187,10 +188,12 @@ cdef class SpanGroup: DOCS: https://spacy.io/api/spangroup#to_bytes """ + cdef SpanC* span_c output = {"name": self.name, "attrs": self.attrs, "spans": []} cdef int i for i in range(self.c.size()): span = self.c[i] + span_c = span.get() # The struct.pack here is probably overkill, but it might help if # you're saving tonnes of spans, and it doesn't really add any # complexity. We do take care to specify little-endian byte order @@ -202,13 +205,13 @@ cdef class SpanGroup: # l: int32_t output["spans"].append(struct.pack( ">QQQllll", - span.id, - span.kb_id, - span.label, - span.start, - span.end, - span.start_char, - span.end_char + span_c.id, + span_c.kb_id, + span_c.label, + span_c.start, + span_c.end, + span_c.start_char, + span_c.end_char )) return srsly.msgpack_dumps(output) @@ -235,10 +238,10 @@ cdef class SpanGroup: span.end = items[4] span.start_char = items[5] span.end_char = items[6] - self.c.push_back(span) + self.c.push_back(make_shared[SpanC](span)) return self - cdef void push_back(self, SpanC span) nogil: + cdef void push_back(self, const shared_ptr[SpanC] &span): self.c.push_back(span) def copy(self) -> SpanGroup: