Example Dict format consistency (#5858)

* consistently use upper-case IDs in token_annotation format and for get_aligned

* remove ID from to_dict (not used in from_dict either)

* fix test

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
Sofie Van Landeghem 2020-08-04 22:22:26 +02:00 committed by GitHub
parent fa79a0db9f
commit 34873c4911
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 32 additions and 24 deletions

View File

@ -183,15 +183,15 @@ cdef class Example:
"links": self._links_to_dict()
},
"token_annotation": {
"ids": [t.i+1 for t in self.reference],
"words": [t.text for t in self.reference],
"tags": [t.tag_ for t in self.reference],
"lemmas": [t.lemma_ for t in self.reference],
"pos": [t.pos_ for t in self.reference],
"morphs": [t.morph_ for t in self.reference],
"heads": [t.head.i for t in self.reference],
"deps": [t.dep_ for t in self.reference],
"sent_starts": [int(bool(t.is_sent_start)) for t in self.reference]
"ORTH": [t.text for t in self.reference],
"SPACY": [bool(t.whitespace_) for t in self.reference],
"TAG": [t.tag_ for t in self.reference],
"LEMMA": [t.lemma_ for t in self.reference],
"POS": [t.pos_ for t in self.reference],
"MORPH": [t.morph_ for t in self.reference],
"HEAD": [t.head.i for t in self.reference],
"DEP": [t.dep_ for t in self.reference],
"SENT_START": [int(bool(t.is_sent_start)) for t in self.reference]
}
}
@ -335,10 +335,14 @@ def _fix_legacy_dict_data(example_dict):
for key, value in old_token_dict.items():
if key in ("text", "ids", "brackets"):
pass
elif key in remapping.values():
token_dict[key] = value
elif key.lower() in remapping:
token_dict[remapping[key.lower()]] = value
else:
raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=remapping.keys()))
all_keys = set(remapping.values())
all_keys.update(remapping.keys())
raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=all_keys))
text = example_dict.get("text", example_dict.get("raw"))
if _has_field(token_dict, "ORTH") and not _has_field(token_dict, "SPACY"):
token_dict["SPACY"] = _guess_spaces(text, token_dict["ORTH"])

View File

@ -108,7 +108,7 @@ class SentenceRecognizer(Tagger):
truths = []
for eg in examples:
eg_truth = []
for x in eg.get_aligned("sent_start"):
for x in eg.get_aligned("SENT_START"):
if x is None:
eg_truth.append(None)
elif x == 1:

View File

@ -259,7 +259,7 @@ class Tagger(Pipe):
DOCS: https://spacy.io/api/tagger#get_loss
"""
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
truths = [eg.get_aligned("tag", as_string=True) for eg in examples]
truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
d_scores, loss = loss_func(scores, truths)
if self.model.ops.xp.isnan(loss):
raise ValueError("nan value when computing loss")

View File

@ -646,14 +646,14 @@ def test_split_sents(merged_dict):
assert split_examples[1].text == "It is just me"
token_annotation_1 = split_examples[0].to_dict()["token_annotation"]
assert token_annotation_1["words"] == ["Hi", "there", "everyone"]
assert token_annotation_1["tags"] == ["INTJ", "ADV", "PRON"]
assert token_annotation_1["sent_starts"] == [1, 0, 0]
assert token_annotation_1["ORTH"] == ["Hi", "there", "everyone"]
assert token_annotation_1["TAG"] == ["INTJ", "ADV", "PRON"]
assert token_annotation_1["SENT_START"] == [1, 0, 0]
token_annotation_2 = split_examples[1].to_dict()["token_annotation"]
assert token_annotation_2["words"] == ["It", "is", "just", "me"]
assert token_annotation_2["tags"] == ["PRON", "AUX", "ADV", "PRON"]
assert token_annotation_2["sent_starts"] == [1, 0, 0, 0]
assert token_annotation_2["ORTH"] == ["It", "is", "just", "me"]
assert token_annotation_2["TAG"] == ["PRON", "AUX", "ADV", "PRON"]
assert token_annotation_2["SENT_START"] == [1, 0, 0, 0]
def test_alignment():

View File

@ -42,7 +42,7 @@ def test_Example_from_dict_with_tags(pred_words, annots):
example = Example.from_dict(predicted, annots)
for i, token in enumerate(example.reference):
assert token.tag_ == annots["tags"][i]
aligned_tags = example.get_aligned("tag", as_string=True)
aligned_tags = example.get_aligned("TAG", as_string=True)
assert aligned_tags == ["NN" for _ in predicted]
@ -53,9 +53,13 @@ def test_aligned_tags():
annots = {"words": gold_words, "tags": gold_tags}
vocab = Vocab()
predicted = Doc(vocab, words=pred_words)
example = Example.from_dict(predicted, annots)
aligned_tags = example.get_aligned("tag", as_string=True)
assert aligned_tags == ["VERB", "DET", "NOUN", "SCONJ", "PRON", "VERB", "VERB"]
example1 = Example.from_dict(predicted, annots)
aligned_tags1 = example1.get_aligned("TAG", as_string=True)
assert aligned_tags1 == ["VERB", "DET", "NOUN", "SCONJ", "PRON", "VERB", "VERB"]
# ensure that to_dict works correctly
example2 = Example.from_dict(predicted, example1.to_dict())
aligned_tags2 = example2.get_aligned("TAG", as_string=True)
assert aligned_tags2 == ["VERB", "DET", "NOUN", "SCONJ", "PRON", "VERB", "VERB"]
def test_aligned_tags_multi():
@ -66,7 +70,7 @@ def test_aligned_tags_multi():
vocab = Vocab()
predicted = Doc(vocab, words=pred_words)
example = Example.from_dict(predicted, annots)
aligned_tags = example.get_aligned("tag", as_string=True)
aligned_tags = example.get_aligned("TAG", as_string=True)
assert aligned_tags == [None, None, "SCONJ", "PRON", "VERB", "VERB"]