start testing get_aligned

This commit is contained in:
svlandeg 2020-06-15 17:16:01 +02:00
parent fd5f199feb
commit 41d29983a7
3 changed files with 27 additions and 24 deletions

View File

@@ -620,10 +620,6 @@ class Errors(object):
E997 = ("Tokenizer special cases are not allowed to modify the text. " E997 = ("Tokenizer special cases are not allowed to modify the text. "
"This would map '{chunk}' to '{orth}' given token attributes " "This would map '{chunk}' to '{orth}' given token attributes "
"'{token_attrs}'.") "'{token_attrs}'.")
E998 = ("To create GoldParse objects from Example objects without a "
"Doc, get_gold_parses() should be called with a Vocab object.")
E999 = ("Encountered an unexpected format for the dictionary holding "
"gold annotations: {gold_dict}")
@add_codes @add_codes

View File

@@ -9,11 +9,11 @@ from .align import Alignment
from ..errors import Errors, AlignmentError from ..errors import Errors, AlignmentError
cpdef Doc annotations2doc(Doc predicted, tok_annot, doc_annot): cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
# TODO: Improve and test this """ Create a Doc from dictionaries with token and doc annotations. Assumes ORTH is set. """
words = tok_annot.get("ORTH", [tok.text for tok in predicted]) words = tok_annot["ORTH"]
attrs, array = _annot2array(predicted, tok_annot, doc_annot) attrs, array = _annot2array(vocab, tok_annot, doc_annot)
output = Doc(predicted.vocab, words=words) output = Doc(vocab, words=words)
if array.size: if array.size:
output = output.from_array(attrs, array) output = output.from_array(attrs, array)
output.cats.update(doc_annot.get("cats", {})) output.cats.update(doc_annot.get("cats", {}))
@@ -23,6 +23,7 @@ cpdef Doc annotations2doc(Doc predicted, tok_annot, doc_annot):
cdef class Example: cdef class Example:
def __init__(self, Doc predicted, Doc reference, *, Alignment alignment=None): def __init__(self, Doc predicted, Doc reference, *, Alignment alignment=None):
""" Doc can either be text, or an actual Doc """ """ Doc can either be text, or an actual Doc """
assert predicted.vocab is reference.vocab
msg = "Example.__init__ got None for '{arg}'. Requires Doc." msg = "Example.__init__ got None for '{arg}'. Requires Doc."
if predicted is None: if predicted is None:
raise TypeError(msg.format(arg="predicted")) raise TypeError(msg.format(arg="predicted"))
@@ -52,11 +53,13 @@ cdef class Example:
raise ValueError("Example.from_dict expected dict, received None") raise ValueError("Example.from_dict expected dict, received None")
if not isinstance(predicted, Doc): if not isinstance(predicted, Doc):
raise TypeError(f"Argument 1 should be Doc. Got {type(predicted)}") raise TypeError(f"Argument 1 should be Doc. Got {type(predicted)}")
example_dict = _fix_legacy_dict_data(predicted, example_dict) example_dict = _fix_legacy_dict_data(example_dict)
tok_dict, doc_dict = _parse_example_dict_data(example_dict) tok_dict, doc_dict = _parse_example_dict_data(example_dict)
if "ORTH" not in tok_dict:
tok_dict["ORTH"] = [tok.text for tok in predicted]
return Example( return Example(
predicted, predicted,
annotations2doc(predicted, tok_dict, doc_dict) annotations2doc(predicted.vocab, tok_dict, doc_dict)
) )
@property @property
@@ -78,6 +81,7 @@ cdef class Example:
gold_to_cand = alignment.gold_to_cand gold_to_cand = alignment.gold_to_cand
cand_to_gold = alignment.cand_to_gold cand_to_gold = alignment.cand_to_gold
vocab = self.reference.vocab
gold_values = self.reference.to_array([field]) gold_values = self.reference.to_array([field])
output = [] output = []
for i, gold_i in enumerate(cand_to_gold): for i, gold_i in enumerate(cand_to_gold):
@@ -85,11 +89,11 @@ cdef class Example:
output.append(None) output.append(None)
elif gold_i is None: elif gold_i is None:
if i in i2j_multi: if i in i2j_multi:
output.append(gold_values[i2j_multi[i]]) output.append(vocab.strings[gold_values[i2j_multi[i]]])
else: else:
output.append(None) output.append(None)
else: else:
output.append(gold_values[gold_i]) output.append(vocab.strings[gold_values[gold_i]])
return output return output
def to_dict(self): def to_dict(self):
@@ -139,21 +143,21 @@ cdef class Example:
return self.x.text return self.x.text
def _annot2array(predicted, tok_annot, doc_annot): def _annot2array(vocab, tok_annot, doc_annot):
attrs = [] attrs = []
values = [] values = []
for key, value in doc_annot.items(): for key, value in doc_annot.items():
if key == "entities": if key == "entities":
words = tok_annot.get("ORTH", [tok.text for tok in predicted]) words = tok_annot["ORTH"]
ent_iobs, ent_types = _parse_ner_tags(predicted.vocab, words, value) ent_iobs, ent_types = _parse_ner_tags(vocab, words, value)
tok_annot["ENT_IOB"] = ent_iobs tok_annot["ENT_IOB"] = ent_iobs
tok_annot["ENT_TYPE"] = ent_types tok_annot["ENT_TYPE"] = ent_types
elif key == "links": elif key == "links":
entities = doc_annot.get("entities", {}) entities = doc_annot.get("entities", {})
if value and not entities: if value and not entities:
raise ValueError(Errors.E981) raise ValueError(Errors.E981)
ent_kb_ids = _parse_links(predicted.vocab, words, value, entities) ent_kb_ids = _parse_links(vocab, words, value, entities)
tok_annot["ENT_KB_ID"] = ent_kb_ids tok_annot["ENT_KB_ID"] = ent_kb_ids
elif key == "cats": elif key == "cats":
pass pass
@@ -173,7 +177,7 @@ def _annot2array(predicted, tok_annot, doc_annot):
values.append(value) values.append(value)
elif key == "MORPH": elif key == "MORPH":
attrs.append(key) attrs.append(key)
values.append([predicted.vocab.morphology.add(v) for v in value]) values.append([vocab.morphology.add(v) for v in value])
elif key == "ENT_IOB": elif key == "ENT_IOB":
iob_strings = Token.iob_strings() iob_strings = Token.iob_strings()
attrs.append(key) attrs.append(key)
@@ -183,7 +187,7 @@ def _annot2array(predicted, tok_annot, doc_annot):
raise ValueError(Errors.E982.format(values=iob_strings, value=values)) raise ValueError(Errors.E982.format(values=iob_strings, value=values))
else: else:
attrs.append(key) attrs.append(key)
values.append([predicted.vocab.strings.add(v) for v in value]) values.append([vocab.strings.add(v) for v in value])
array = numpy.asarray(values, dtype="uint64") array = numpy.asarray(values, dtype="uint64")
return attrs, array.T return attrs, array.T
@@ -196,7 +200,7 @@ def _parse_example_dict_data(example_dict):
) )
def _fix_legacy_dict_data(predicted, example_dict): def _fix_legacy_dict_data(example_dict):
token_dict = example_dict.get("token_annotation", {}) token_dict = example_dict.get("token_annotation", {})
doc_dict = example_dict.get("doc_annotation", {}) doc_dict = example_dict.get("doc_annotation", {})
for key, value in example_dict.items(): for key, value in example_dict.items():

View File

@@ -28,17 +28,20 @@ def test_Example_from_dict_basic():
def test_Example_from_dict_invalid(annots): def test_Example_from_dict_invalid(annots):
vocab = Vocab() vocab = Vocab()
predicted = Doc(vocab, words=annots["words"]) predicted = Doc(vocab, words=annots["words"])
with pytest.raises(ValueError): with pytest.raises(KeyError):
Example.from_dict(predicted, annots) Example.from_dict(predicted, annots)
@pytest.mark.parametrize("annots", [{"words": ["ice", "cream"], "tags": ["NN", "NN"]}]) @pytest.mark.parametrize("gold_words", [["ice", "cream"], ["icecream"], ["i", "ce", "cream"]])
def test_Example_from_dict_with_tags(annots): @pytest.mark.parametrize("annots", [{"words": ["icecream"], "tags": ["NN"]}])
def test_Example_from_dict_with_tags(gold_words, annots):
vocab = Vocab() vocab = Vocab()
predicted = Doc(vocab, words=annots["words"]) predicted = Doc(vocab, words=gold_words)
example = Example.from_dict(predicted, annots) example = Example.from_dict(predicted, annots)
for i, token in enumerate(example.reference): for i, token in enumerate(example.reference):
assert token.tag_ == annots["tags"][i] assert token.tag_ == annots["tags"][i]
aligned_tags = example.get_aligned("tag")
assert aligned_tags == ["NN" for _ in predicted]
@pytest.mark.parametrize( @pytest.mark.parametrize(