mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Support doc.spans in Example.from_dict (#7197)
* add support for spans in Example.from_dict * add unit tests * update error to E879
This commit is contained in:
parent
fb98862337
commit
212f0e779e
|
@ -321,7 +321,8 @@ class Errors:
|
||||||
"https://spacy.io/api/top-level#util.filter_spans")
|
"https://spacy.io/api/top-level#util.filter_spans")
|
||||||
E103 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A "
|
E103 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A "
|
||||||
"token can only be part of one entity, so make sure the entities "
|
"token can only be part of one entity, so make sure the entities "
|
||||||
"you're setting don't overlap.")
|
"you're setting don't overlap. To work with overlapping entities, "
|
||||||
|
"consider using doc.spans instead.")
|
||||||
E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore "
|
E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore "
|
||||||
"settings: {opts}")
|
"settings: {opts}")
|
||||||
E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}")
|
E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}")
|
||||||
|
@ -487,6 +488,9 @@ class Errors:
|
||||||
|
|
||||||
# New errors added in v3.x
|
# New errors added in v3.x
|
||||||
|
|
||||||
|
E879 = ("Unexpected type for 'spans' data. Provide a dictionary mapping keys to "
|
||||||
|
"a list of spans, with each span represented by a tuple (start_char, end_char). "
|
||||||
|
"The tuple can be optionally extended with a label and a KB ID.")
|
||||||
E880 = ("The 'wandb' library could not be found - did you install it? "
|
E880 = ("The 'wandb' library could not be found - did you install it? "
|
||||||
"Alternatively, specify the 'ConsoleLogger' in the 'training.logger' "
|
"Alternatively, specify the 'ConsoleLogger' in the 'training.logger' "
|
||||||
"config section, instead of the 'WandbLogger'.")
|
"config section, instead of the 'WandbLogger'.")
|
||||||
|
|
|
@ -196,6 +196,104 @@ def test_Example_from_dict_with_entities_invalid(annots):
|
||||||
assert len(list(example.reference.ents)) == 0
|
assert len(list(example.reference.ents)) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "annots",
    [
        {
            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
            # The (7, 15) and (11, 15) offsets overlap on purpose.
            "entities": [(7, 15, "LOC"), (11, 15, "LOC"), (20, 26, "LOC")],
        }
    ],
)
def test_Example_from_dict_with_entities_overlapping(annots):
    """Overlapping `entities` annotations must be rejected with a ValueError."""
    pred_doc = Doc(Vocab(), words=annots["words"])
    with pytest.raises(ValueError):
        Example.from_dict(pred_doc, annots)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "annots",
    [
        {
            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
            "spans": {
                "cities": [(7, 15, "LOC"), (20, 26, "LOC")],
                "people": [(0, 1, "PERSON")],
            },
        }
    ],
)
def test_Example_from_dict_with_spans(annots):
    """`spans` annotations end up in `reference.spans`, leaving `ents` empty."""
    pred_doc = Doc(Vocab(), words=annots["words"])
    example = Example.from_dict(pred_doc, annots)
    # Span groups must not leak into the entity annotations.
    assert len(list(example.reference.ents)) == 0
    # Expected (number of spans, label) for each span group key.
    expected = {"cities": (2, "LOC"), "people": (1, "PERSON")}
    for group_key, (n_spans, _) in expected.items():
        assert len(list(example.reference.spans[group_key])) == n_spans
    for group_key, (_, label) in expected.items():
        for span in example.reference.spans[group_key]:
            assert span.label_ == label
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "annots",
    [
        {
            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
            "spans": {
                # The (7, 15) and (11, 15) offsets overlap on purpose.
                "cities": [(7, 15, "LOC"), (11, 15, "LOC"), (20, 26, "LOC")],
                "people": [(0, 1, "PERSON")],
            },
        }
    ],
)
def test_Example_from_dict_with_spans_overlapping(annots):
    """Overlapping spans within one span group are accepted and all kept."""
    pred_doc = Doc(Vocab(), words=annots["words"])
    example = Example.from_dict(pred_doc, annots)
    # Span groups must not leak into the entity annotations.
    assert len(list(example.reference.ents)) == 0
    # Expected (number of spans, label) for each span group key.
    expected = {"cities": (3, "LOC"), "people": (1, "PERSON")}
    for group_key, (n_spans, _) in expected.items():
        assert len(list(example.reference.spans[group_key])) == n_spans
    for group_key, (_, label) in expected.items():
        for span in example.reference.spans[group_key]:
            assert span.label_ == label
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "annots",
    [
        # "spans" is a list instead of a dict
        {
            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
            "spans": [(0, 1, "PERSON")],
        },
        # the group's value is a single tuple instead of a list
        {
            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
            "spans": {"cities": (7, 15, "LOC")},
        },
        # the list entries are bare ints instead of tuples/lists
        {
            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
            "spans": {"cities": [7, 11]},
        },
        # the span entry has fewer than two elements
        {
            "words": ["I", "like", "New", "York", "and", "Berlin", "."],
            "spans": {"cities": [[7]]},
        },
    ],
)
def test_Example_from_dict_with_spans_invalid(annots):
    """Malformed `spans` annotations must be rejected with a ValueError."""
    pred_doc = Doc(Vocab(), words=annots["words"])
    with pytest.raises(ValueError):
        Example.from_dict(pred_doc, annots)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"annots",
|
"annots",
|
||||||
[
|
[
|
||||||
|
|
|
@ -22,6 +22,8 @@ cpdef Doc annotations_to_doc(vocab, tok_annot, doc_annot):
|
||||||
output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"])
|
output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"])
|
||||||
if "entities" in doc_annot:
|
if "entities" in doc_annot:
|
||||||
_add_entities_to_doc(output, doc_annot["entities"])
|
_add_entities_to_doc(output, doc_annot["entities"])
|
||||||
|
if "spans" in doc_annot:
|
||||||
|
_add_spans_to_doc(output, doc_annot["spans"])
|
||||||
if array.size:
|
if array.size:
|
||||||
output = output.from_array(attrs, array)
|
output = output.from_array(attrs, array)
|
||||||
# links are currently added with ENT_KB_ID on the token level
|
# links are currently added with ENT_KB_ID on the token level
|
||||||
|
@ -314,13 +316,11 @@ def _annot2array(vocab, tok_annot, doc_annot):
|
||||||
|
|
||||||
for key, value in doc_annot.items():
|
for key, value in doc_annot.items():
|
||||||
if value:
|
if value:
|
||||||
if key == "entities":
|
if key in ["entities", "cats", "spans"]:
|
||||||
pass
|
pass
|
||||||
elif key == "links":
|
elif key == "links":
|
||||||
ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], tok_annot["SPACY"], value)
|
ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], tok_annot["SPACY"], value)
|
||||||
tok_annot["ENT_KB_ID"] = ent_kb_ids
|
tok_annot["ENT_KB_ID"] = ent_kb_ids
|
||||||
elif key == "cats":
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(Errors.E974.format(obj="doc", key=key))
|
raise ValueError(Errors.E974.format(obj="doc", key=key))
|
||||||
|
|
||||||
|
@ -351,6 +351,29 @@ def _annot2array(vocab, tok_annot, doc_annot):
|
||||||
return attrs, array.T
|
return attrs, array.T
|
||||||
|
|
||||||
|
|
||||||
|
def _add_spans_to_doc(doc, spans_data):
    """Fill `doc.spans` from character-offset span annotations.

    doc: The object to annotate. Only `doc.char_span(start, end, label=...,
        kb_id=...)` and item assignment on `doc.spans` are used here.
    spans_data (dict): Maps each span group key to a list of spans. Each span
        is a list or tuple `(start_char, end_char)`, optionally extended with
        a label and then a KB ID.
    RAISES (ValueError): With E879 if `spans_data` is not a dict, a group's
        value is not a list, or a span entry is not a list/tuple of at least
        two items. Also raised if `doc.char_span` returns None for a span.
    """
    if not isinstance(spans_data, dict):
        raise ValueError(Errors.E879)
    for key, span_list in spans_data.items():
        if not isinstance(span_list, list):
            raise ValueError(Errors.E879)
        spans = []
        for span_tuple in span_list:
            if not isinstance(span_tuple, (list, tuple)) or len(span_tuple) < 2:
                raise ValueError(Errors.E879)
            start_char, end_char = span_tuple[0], span_tuple[1]
            # Label and KB ID are optional trailing items; 0 means "not set".
            label = span_tuple[2] if len(span_tuple) > 2 else 0
            kb_id = span_tuple[3] if len(span_tuple) > 3 else 0
            span = doc.char_span(start_char, end_char, label=label, kb_id=kb_id)
            if span is None:
                # char_span is documented to return None when the character
                # offsets don't map to a valid span. Fail here with a clear
                # message rather than storing None in the span group.
                raise ValueError(
                    "Can't create a span for key '{key}' from character "
                    "offsets ({start}, {end}): the offsets don't map to a "
                    "valid span in the doc.".format(
                        key=key, start=start_char, end=end_char
                    )
                )
            spans.append(span)
        doc.spans[key] = spans
|
||||||
|
|
||||||
|
|
||||||
def _add_entities_to_doc(doc, ner_data):
|
def _add_entities_to_doc(doc, ner_data):
|
||||||
if ner_data is None:
|
if ner_data is None:
|
||||||
return
|
return
|
||||||
|
@ -397,7 +420,7 @@ def _fix_legacy_dict_data(example_dict):
|
||||||
pass
|
pass
|
||||||
elif key == "ids":
|
elif key == "ids":
|
||||||
pass
|
pass
|
||||||
elif key in ("cats", "links"):
|
elif key in ("cats", "links", "spans"):
|
||||||
doc_dict[key] = value
|
doc_dict[key] = value
|
||||||
elif key in ("ner", "entities"):
|
elif key in ("ner", "entities"):
|
||||||
doc_dict["entities"] = value
|
doc_dict["entities"] = value
|
||||||
|
|
Loading…
Reference in New Issue
Block a user