Support doc.spans in Example.from_dict (#7197)

* add support for spans in Example.from_dict

* add unit tests

* update error to E879
Sofie Van Landeghem 2021-03-02 15:12:54 +01:00 committed by GitHub
parent fb98862337
commit 212f0e779e
3 changed files with 130 additions and 5 deletions
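
For orientation, a minimal usage sketch of what this change enables, assembled from the unit tests in this diff (the variable names are illustrative, not part of the commit):

from spacy.tokens import Doc
from spacy.training import Example
from spacy.vocab import Vocab

# Sketch: pass span groups through Example.from_dict via the new "spans" key.
words = ["I", "like", "New", "York", "and", "Berlin", "."]
predicted = Doc(Vocab(), words=words)
annots = {
    "words": words,
    # each span is (start_char, end_char), optionally extended with a label and a KB ID
    "spans": {
        "cities": [(7, 15, "LOC"), (20, 26, "LOC")],
        "people": [(0, 1, "PERSON")],
    },
}
example = Example.from_dict(predicted, annots)
assert len(example.reference.spans["cities"]) == 2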

View File

@@ -321,7 +321,8 @@ class Errors:
             "https://spacy.io/api/top-level#util.filter_spans")
     E103 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A "
             "token can only be part of one entity, so make sure the entities "
-            "you're setting don't overlap.")
+            "you're setting don't overlap. To work with overlapping entities, "
+            "consider using doc.spans instead.")
     E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore "
             "settings: {opts}")
     E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}")
@@ -487,6 +488,9 @@ class Errors:
     # New errors added in v3.x
+    E879 = ("Unexpected type for 'spans' data. Provide a dictionary mapping keys to "
+            "a list of spans, with each span represented by a tuple (start_char, end_char). "
+            "The tuple can be optionally extended with a label and a KB ID.")
     E880 = ("The 'wandb' library could not be found - did you install it? "
             "Alternatively, specify the 'ConsoleLogger' in the 'training.logger' "
             "config section, instead of the 'WandbLogger'.")

View File

@@ -196,6 +196,104 @@ def test_Example_from_dict_with_entities_invalid(annots):
     assert len(list(example.reference.ents)) == 0
 
 
+@pytest.mark.parametrize(
+    "annots",
+    [
+        {
+            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
+            "entities": [
+                (7, 15, "LOC"),
+                (11, 15, "LOC"),
+                (20, 26, "LOC"),
+            ],  # overlapping
+        }
+    ],
+)
+def test_Example_from_dict_with_entities_overlapping(annots):
+    vocab = Vocab()
+    predicted = Doc(vocab, words=annots["words"])
+    with pytest.raises(ValueError):
+        Example.from_dict(predicted, annots)
+
+
+@pytest.mark.parametrize(
+    "annots",
+    [
+        {
+            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
+            "spans": {
+                "cities": [(7, 15, "LOC"), (20, 26, "LOC")],
+                "people": [(0, 1, "PERSON")],
+            },
+        }
+    ],
+)
+def test_Example_from_dict_with_spans(annots):
+    vocab = Vocab()
+    predicted = Doc(vocab, words=annots["words"])
+    example = Example.from_dict(predicted, annots)
+    assert len(list(example.reference.ents)) == 0
+    assert len(list(example.reference.spans["cities"])) == 2
+    assert len(list(example.reference.spans["people"])) == 1
+    for span in example.reference.spans["cities"]:
+        assert span.label_ == "LOC"
+    for span in example.reference.spans["people"]:
+        assert span.label_ == "PERSON"
+
+
+@pytest.mark.parametrize(
+    "annots",
+    [
+        {
+            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
+            "spans": {
+                "cities": [(7, 15, "LOC"), (11, 15, "LOC"), (20, 26, "LOC")],
+                "people": [(0, 1, "PERSON")],
+            },
+        }
+    ],
+)
+def test_Example_from_dict_with_spans_overlapping(annots):
+    vocab = Vocab()
+    predicted = Doc(vocab, words=annots["words"])
+    example = Example.from_dict(predicted, annots)
+    assert len(list(example.reference.ents)) == 0
+    assert len(list(example.reference.spans["cities"])) == 3
+    assert len(list(example.reference.spans["people"])) == 1
+    for span in example.reference.spans["cities"]:
+        assert span.label_ == "LOC"
+    for span in example.reference.spans["people"]:
+        assert span.label_ == "PERSON"
+
+
+@pytest.mark.parametrize(
+    "annots",
+    [
+        {
+            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
+            "spans": [(0, 1, "PERSON")],
+        },
+        {
+            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
+            "spans": {"cities": (7, 15, "LOC")},
+        },
+        {
+            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
+            "spans": {"cities": [7, 11]},
+        },
+        {
+            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
+            "spans": {"cities": [[7]]},
+        },
+    ],
+)
+def test_Example_from_dict_with_spans_invalid(annots):
+    vocab = Vocab()
+    predicted = Doc(vocab, words=annots["words"])
+    with pytest.raises(ValueError):
+        Example.from_dict(predicted, annots)
+
+
 @pytest.mark.parametrize(
     "annots",
     [

View File

@@ -22,6 +22,8 @@ cpdef Doc annotations_to_doc(vocab, tok_annot, doc_annot):
     output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"])
     if "entities" in doc_annot:
         _add_entities_to_doc(output, doc_annot["entities"])
+    if "spans" in doc_annot:
+        _add_spans_to_doc(output, doc_annot["spans"])
     if array.size:
         output = output.from_array(attrs, array)
     # links are currently added with ENT_KB_ID on the token level
@@ -314,13 +316,11 @@ def _annot2array(vocab, tok_annot, doc_annot):
     for key, value in doc_annot.items():
         if value:
-            if key == "entities":
+            if key in ["entities", "cats", "spans"]:
                 pass
             elif key == "links":
                 ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], tok_annot["SPACY"], value)
                 tok_annot["ENT_KB_ID"] = ent_kb_ids
-            elif key == "cats":
-                pass
             else:
                 raise ValueError(Errors.E974.format(obj="doc", key=key))
@@ -351,6 +351,29 @@ def _annot2array(vocab, tok_annot, doc_annot):
     return attrs, array.T
 
 
+def _add_spans_to_doc(doc, spans_data):
+    if not isinstance(spans_data, dict):
+        raise ValueError(Errors.E879)
+    for key, span_list in spans_data.items():
+        spans = []
+        if not isinstance(span_list, list):
+            raise ValueError(Errors.E879)
+        for span_tuple in span_list:
+            if not isinstance(span_tuple, (list, tuple)) or len(span_tuple) < 2:
+                raise ValueError(Errors.E879)
+            start_char = span_tuple[0]
+            end_char = span_tuple[1]
+            label = 0
+            kb_id = 0
+            if len(span_tuple) > 2:
+                label = span_tuple[2]
+            if len(span_tuple) > 3:
+                kb_id = span_tuple[3]
+            span = doc.char_span(start_char, end_char, label=label, kb_id=kb_id)
+            spans.append(span)
+        doc.spans[key] = spans
+
+
 def _add_entities_to_doc(doc, ner_data):
     if ner_data is None:
         return
@@ -397,7 +420,7 @@ def _fix_legacy_dict_data(example_dict):
             pass
         elif key == "ids":
             pass
-        elif key in ("cats", "links"):
+        elif key in ("cats", "links", "spans"):
            doc_dict[key] = value
         elif key in ("ner", "entities"):
             doc_dict["entities"] = value