Require that all SpanGroup spans are from the current doc (#12569)

* Require that all SpanGroup spans are from the current doc

The restriction on only adding spans from the current doc were already
implemented for all operations except for `SpanGroup.__init__`.

Initialize copied spans for `SpanGroup.copy` with `Doc.char_span` in
order to validate the character offsets and to make it possible to copy
spans between documents with differing tokenization. Currently there is
no validation that the document texts are identical, but the span char
offsets must be valid spans in the target doc, which prevents you from
ending up with completely invalid spans.

* Undo change in test_beam_overfitting_IO
This commit is contained in:
Adriane Boyd 2023-06-01 19:19:17 +02:00 committed by GitHub
parent 05df59fd4a
commit c4112a1da3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 51 additions and 9 deletions

View File

@ -970,6 +970,9 @@ class Errors(metaclass=ErrorsWithCodes):
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` " E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
"or use `auto_select_port=True` to pick an available port automatically.") "or use `auto_select_port=True` to pick an available port automatically.")
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.") E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
E1052 = ("Unable to copy spans: the character offsets for the span at "
"index {i} in the span group do not align with the tokenization "
"in the target doc.")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -93,6 +93,21 @@ def test_span_group_copy(doc):
assert span_group.attrs["key"] == "value" assert span_group.attrs["key"] == "value"
assert list(span_group) != list(clone) assert list(span_group) != list(clone)
# can't copy if the character offsets don't align to tokens
doc2 = Doc(doc.vocab, words=[t.text + "x" for t in doc])
with pytest.raises(ValueError):
span_group.copy(doc=doc2)
# can copy with valid character offsets despite different tokenization
doc3 = doc.copy()
with doc3.retokenize() as retokenizer:
retokenizer.merge(doc3[0:2])
retokenizer.merge(doc3[3:6])
span_group = SpanGroup(doc, spans=[doc[0:6], doc[3:6]])
for span1, span2 in zip(span_group, span_group.copy(doc=doc3)):
assert span1.start_char == span2.start_char
assert span1.end_char == span2.end_char
def test_span_group_set_item(doc, other_doc): def test_span_group_set_item(doc, other_doc):
span_group = doc.spans["SPANS"] span_group = doc.spans["SPANS"]
@ -253,3 +268,12 @@ def test_span_group_typing(doc: Doc):
for i, span in enumerate(span_group): for i, span in enumerate(span_group):
assert span == span_group[i] == spans[i] assert span == span_group[i] == spans[i]
filter_spans(span_group) filter_spans(span_group)
def test_span_group_init_doc(en_tokenizer):
"""Test that all spans must come from the specified doc."""
doc1 = en_tokenizer("a b c")
doc2 = en_tokenizer("a b c")
span_group = SpanGroup(doc1, spans=[doc1[0:1], doc1[1:2]])
with pytest.raises(ValueError):
span_group = SpanGroup(doc1, spans=[doc1[0:1], doc2[1:2]])

View File

@ -728,9 +728,9 @@ def test_neg_annotation(neg_key):
ner.add_label("ORG") ner.add_label("ORG")
example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]}) example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]})
example.reference.spans[neg_key] = [ example.reference.spans[neg_key] = [
Span(neg_doc, 2, 4, "ORG"), Span(example.reference, 2, 4, "ORG"),
Span(neg_doc, 2, 3, "PERSON"), Span(example.reference, 2, 3, "PERSON"),
Span(neg_doc, 1, 4, "PERSON"), Span(example.reference, 1, 4, "PERSON"),
] ]
optimizer = nlp.initialize() optimizer = nlp.initialize()
@ -755,7 +755,7 @@ def test_neg_annotation_conflict(neg_key):
ner.add_label("PERSON") ner.add_label("PERSON")
ner.add_label("LOC") ner.add_label("LOC")
example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]}) example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]})
example.reference.spans[neg_key] = [Span(neg_doc, 2, 4, "PERSON")] example.reference.spans[neg_key] = [Span(example.reference, 2, 4, "PERSON")]
assert len(example.reference.ents) == 1 assert len(example.reference.ents) == 1
assert example.reference.ents[0].text == "Shaka Khan" assert example.reference.ents[0].text == "Shaka Khan"
assert example.reference.ents[0].label_ == "PERSON" assert example.reference.ents[0].label_ == "PERSON"
@ -788,7 +788,7 @@ def test_beam_valid_parse(neg_key):
doc = Doc(nlp.vocab, words=tokens) doc = Doc(nlp.vocab, words=tokens)
example = Example.from_dict(doc, {"ner": iob}) example = Example.from_dict(doc, {"ner": iob})
neg_span = Span(doc, 50, 53, "ORG") neg_span = Span(example.reference, 50, 53, "ORG")
example.reference.spans[neg_key] = [neg_span] example.reference.spans[neg_key] = [neg_span]
optimizer = nlp.initialize() optimizer = nlp.initialize()

View File

@ -438,14 +438,14 @@ def test_score_spans():
return doc.spans[span_key] return doc.spans[span_key]
# Predict exactly the same, but overlapping spans will be discarded # Predict exactly the same, but overlapping spans will be discarded
pred.spans[key] = spans pred.spans[key] = gold.spans[key].copy(doc=pred)
eg = Example(pred, gold) eg = Example(pred, gold)
scores = Scorer.score_spans([eg], attr=key, getter=span_getter) scores = Scorer.score_spans([eg], attr=key, getter=span_getter)
assert scores[f"{key}_p"] == 1.0 assert scores[f"{key}_p"] == 1.0
assert scores[f"{key}_r"] < 1.0 assert scores[f"{key}_r"] < 1.0
# Allow overlapping, now both precision and recall should be 100% # Allow overlapping, now both precision and recall should be 100%
pred.spans[key] = spans pred.spans[key] = gold.spans[key].copy(doc=pred)
eg = Example(pred, gold) eg = Example(pred, gold)
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True) scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
assert scores[f"{key}_p"] == 1.0 assert scores[f"{key}_p"] == 1.0

View File

@ -1264,12 +1264,14 @@ cdef class Doc:
other.user_span_hooks = dict(self.user_span_hooks) other.user_span_hooks = dict(self.user_span_hooks)
other.length = self.length other.length = self.length
other.max_length = self.max_length other.max_length = self.max_length
other.spans = self.spans.copy(doc=other)
buff_size = other.max_length + (PADDING*2) buff_size = other.max_length + (PADDING*2)
assert buff_size > 0 assert buff_size > 0
tokens = <TokenC*>other.mem.alloc(buff_size, sizeof(TokenC)) tokens = <TokenC*>other.mem.alloc(buff_size, sizeof(TokenC))
memcpy(tokens, self.c - PADDING, buff_size * sizeof(TokenC)) memcpy(tokens, self.c - PADDING, buff_size * sizeof(TokenC))
other.c = &tokens[PADDING] other.c = &tokens[PADDING]
# copy spans after setting tokens so that SpanGroup.copy can verify
# that the start/end offsets are valid
other.spans = self.spans.copy(doc=other)
return other return other
def to_disk(self, path, *, exclude=tuple()): def to_disk(self, path, *, exclude=tuple()):

View File

@ -52,6 +52,8 @@ cdef class SpanGroup:
if len(spans) : if len(spans) :
self.c.reserve(len(spans)) self.c.reserve(len(spans))
for span in spans: for span in spans:
if doc is not span.doc:
raise ValueError(Errors.E855.format(obj="span"))
self.push_back(span.c) self.push_back(span.c)
def __repr__(self): def __repr__(self):
@ -261,11 +263,22 @@ cdef class SpanGroup:
""" """
if doc is None: if doc is None:
doc = self.doc doc = self.doc
if doc is self.doc:
spans = list(self)
else:
spans = [doc.char_span(span.start_char, span.end_char, label=span.label_, kb_id=span.kb_id, span_id=span.id) for span in self]
for i, span in enumerate(spans):
if span is None:
raise ValueError(Errors.E1052.format(i=i))
if span.kb_id in self.doc.vocab.strings:
doc.vocab.strings.add(span.kb_id_)
if span.id in span.doc.vocab.strings:
doc.vocab.strings.add(span.id_)
return SpanGroup( return SpanGroup(
doc, doc,
name=self.name, name=self.name,
attrs=deepcopy(self.attrs), attrs=deepcopy(self.attrs),
spans=list(self), spans=spans,
) )
def _concat( def _concat(