mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-11 23:05:50 +03:00
fixed some tests + WIP roundtrip unit test
This commit is contained in:
parent
43d41d6bb6
commit
ba80ad7efd
|
@ -177,17 +177,19 @@ def _annot2array(vocab, tok_annot, doc_annot):
|
|||
|
||||
for key, value in doc_annot.items():
|
||||
if key == "entities":
|
||||
words = tok_annot["ORTH"]
|
||||
spaces = tok_annot["SPACY"]
|
||||
ent_iobs, ent_types = _parse_ner_tags(value, vocab, words, spaces)
|
||||
tok_annot["ENT_IOB"] = ent_iobs
|
||||
tok_annot["ENT_TYPE"] = ent_types
|
||||
if value:
|
||||
words = tok_annot["ORTH"]
|
||||
spaces = tok_annot["SPACY"]
|
||||
ent_iobs, ent_types = _parse_ner_tags(value, vocab, words, spaces)
|
||||
tok_annot["ENT_IOB"] = ent_iobs
|
||||
tok_annot["ENT_TYPE"] = ent_types
|
||||
elif key == "links":
|
||||
entities = doc_annot.get("entities", {})
|
||||
if value and not entities:
|
||||
raise ValueError(Errors.E981)
|
||||
ent_kb_ids = _parse_links(vocab, words, value, entities)
|
||||
tok_annot["ENT_KB_ID"] = ent_kb_ids
|
||||
if value:
|
||||
entities = doc_annot.get("entities", {})
|
||||
if value and not entities:
|
||||
raise ValueError(Errors.E981)
|
||||
ent_kb_ids = _parse_links(vocab, words, value, entities)
|
||||
tok_annot["ENT_KB_ID"] = ent_kb_ids
|
||||
elif key == "cats":
|
||||
pass
|
||||
else:
|
||||
|
|
|
@ -38,10 +38,16 @@ def docs_to_json(docs, id=0, ner_missing_tag="O"):
|
|||
docs = [docs]
|
||||
json_doc = {"id": id, "paragraphs": []}
|
||||
for i, doc in enumerate(docs):
|
||||
json_para = {'raw': doc.text, "sentences": [], "cats": []}
|
||||
json_para = {'raw': doc.text, "sentences": [], "cats": [], "entities": [], "links": []}
|
||||
for cat, val in doc.cats.items():
|
||||
json_cat = {"label": cat, "value": val}
|
||||
json_para["cats"].append(json_cat)
|
||||
for ent in doc.ents:
|
||||
ent_tuple = (ent.start_char, ent.end_char, ent.label_)
|
||||
json_para["entities"].append(ent_tuple)
|
||||
if ent.kb_id_:
|
||||
link_dict = {(ent.start_char, ent.end_char): {ent.kb_id_: 1.0}}
|
||||
json_para["links"].append(link_dict)
|
||||
ent_offsets = [(e.start_char, e.end_char, e.label_) for e in doc.ents]
|
||||
biluo_tags = biluo_tags_from_offsets(doc, ent_offsets, missing=ner_missing_tag)
|
||||
for j, sent in enumerate(doc.sents):
|
||||
|
@ -56,7 +62,6 @@ def docs_to_json(docs, id=0, ner_missing_tag="O"):
|
|||
if doc.is_parsed:
|
||||
json_token["head"] = token.head.i-token.i
|
||||
json_token["dep"] = token.dep_
|
||||
json_token["ner"] = biluo_tags[token.i]
|
||||
json_sent["tokens"].append(json_token)
|
||||
json_para["sentences"].append(json_sent)
|
||||
json_doc["paragraphs"].append(json_para)
|
||||
|
@ -78,7 +83,7 @@ def read_json_file(loc, docs_filter=None, limit=None):
|
|||
|
||||
def json_to_annotations(doc):
|
||||
"""Convert an item in the JSON-formatted training data to the format
|
||||
used by GoldParse.
|
||||
used by Example.
|
||||
|
||||
doc (dict): One entry in the training data.
|
||||
YIELDS (tuple): The reformatted data - one training example per paragraph
|
||||
|
@ -93,7 +98,6 @@ def json_to_annotations(doc):
|
|||
lemmas = []
|
||||
heads = []
|
||||
labels = []
|
||||
ner = []
|
||||
sent_starts = []
|
||||
brackets = []
|
||||
for sent in paragraph["sentences"]:
|
||||
|
@ -110,7 +114,6 @@ def json_to_annotations(doc):
|
|||
# Ensure ROOT label is case-insensitive
|
||||
if labels[-1].lower() == "root":
|
||||
labels[-1] = "ROOT"
|
||||
ner.append(token.get("ner", "-"))
|
||||
if i == 0:
|
||||
sent_starts.append(1)
|
||||
else:
|
||||
|
@ -119,9 +122,7 @@ def json_to_annotations(doc):
|
|||
brackets.extend((b["first"] + sent_start_i,
|
||||
b["last"] + sent_start_i, b["label"])
|
||||
for b in sent["brackets"])
|
||||
cats = {}
|
||||
for cat in paragraph.get("cats", {}):
|
||||
cats[cat["label"]] = cat["value"]
|
||||
|
||||
example["token_annotation"] = dict(
|
||||
ids=ids,
|
||||
words=words,
|
||||
|
@ -131,15 +132,24 @@ def json_to_annotations(doc):
|
|||
lemmas=lemmas,
|
||||
heads=heads,
|
||||
deps=labels,
|
||||
entities=ner,
|
||||
sent_starts=sent_starts,
|
||||
brackets=brackets
|
||||
)
|
||||
example["doc_annotation"] = dict(cats=cats)
|
||||
|
||||
cats = {}
|
||||
for cat in paragraph.get("cats", {}):
|
||||
cats[cat["label"]] = cat["value"]
|
||||
entities = []
|
||||
for start, end, label in paragraph.get("entities", {}):
|
||||
ent_tuple = (start, end, label)
|
||||
entities.append(ent_tuple)
|
||||
example["doc_annotation"] = dict(
|
||||
cats=cats,
|
||||
entities=entities,
|
||||
links=paragraph.get("links", []) # TODO: fix/test
|
||||
)
|
||||
yield example
|
||||
|
||||
|
||||
|
||||
def json_iterate(loc):
|
||||
# We should've made these files jsonl...But since we didn't, parse out
|
||||
# the docs one-by-one to reduce memory usage.
|
||||
|
|
|
@ -254,16 +254,15 @@ def test_iob_to_biluo():
|
|||
def test_roundtrip_docs_to_json(doc):
|
||||
nlp = English()
|
||||
text = doc.text
|
||||
idx = [t.idx for t in doc]
|
||||
tags = [t.tag_ for t in doc]
|
||||
pos = [t.pos_ for t in doc]
|
||||
morphs = [t.morph_ for t in doc]
|
||||
lemmas = [t.lemma_ for t in doc]
|
||||
deps = [t.dep_ for t in doc]
|
||||
heads = [t.head.i for t in doc]
|
||||
biluo_tags = iob_to_biluo(
|
||||
[t.ent_iob_ + "-" + t.ent_type_ if t.ent_type_ else "O" for t in doc]
|
||||
)
|
||||
cats = doc.cats
|
||||
ents = doc.ents
|
||||
|
||||
# roundtrip to JSON
|
||||
with make_tempdir() as tmpdir:
|
||||
|
@ -274,12 +273,14 @@ def test_roundtrip_docs_to_json(doc):
|
|||
reloaded_example = next(goldcorpus.dev_dataset(nlp=nlp))
|
||||
assert len(doc) == goldcorpus.count_train()
|
||||
assert text == reloaded_example.predicted.text
|
||||
assert idx == [t.idx for t in reloaded_example.reference]
|
||||
assert tags == [t.tag_ for t in reloaded_example.reference]
|
||||
assert pos == [t.pos_ for t in reloaded_example.reference]
|
||||
assert morphs == [t.morph_ for t in reloaded_example.reference]
|
||||
assert lemmas == [t.lemma_ for t in reloaded_example.reference]
|
||||
assert deps == [t.dep_ for t in reloaded_example.reference]
|
||||
assert heads == [t.head.i for t in reloaded_example.reference]
|
||||
assert ents == reloaded_example.reference.ents
|
||||
assert "TRAVEL" in reloaded_example.reference.cats
|
||||
assert "BAKING" in reloaded_example.reference.cats
|
||||
assert cats["TRAVEL"] == reloaded_example.reference.cats["TRAVEL"]
|
||||
|
@ -394,25 +395,28 @@ def test_align(tokens_a, tokens_b, expected):
|
|||
def test_goldparse_startswith_space(en_tokenizer):
|
||||
text = " a"
|
||||
doc = en_tokenizer(text)
|
||||
g = GoldParse(doc, words=["a"], entities=["U-DATE"], deps=["ROOT"], heads=[0])
|
||||
assert g.words == [" ", "a"]
|
||||
assert g.ner == [None, "U-DATE"]
|
||||
assert g.labels == [None, "ROOT"]
|
||||
gold_words = ["a"]
|
||||
entities = ["U-DATE"]
|
||||
deps = ["ROOT"]
|
||||
heads = [0]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities, "deps":deps, "heads": heads})
|
||||
assert example.get_aligned("ENT_IOB") == [None, 3]
|
||||
assert example.get_aligned("ENT_TYPE", as_string=True) == [None, "DATE"]
|
||||
assert example.get_aligned("DEP", as_string=True) == [None, "ROOT"]
|
||||
|
||||
|
||||
def test_gold_constructor():
|
||||
"""Test that the GoldParse constructor works fine"""
|
||||
"""Test that the Example constructor works fine"""
|
||||
nlp = English()
|
||||
doc = nlp("This is a sentence")
|
||||
gold = GoldParse(doc, cats={"cat1": 1.0, "cat2": 0.0})
|
||||
|
||||
assert gold.cats["cat1"]
|
||||
assert not gold.cats["cat2"]
|
||||
assert gold.words == ["This", "is", "a", "sentence"]
|
||||
example = Example.from_dict(doc, {"cats": {"cat1": 1.0, "cat2": 0.0}})
|
||||
assert example.get_aligned("ORTH", as_string=True) == ["This", "is", "a", "sentence"]
|
||||
assert example.reference.cats["cat1"]
|
||||
assert not example.reference.cats["cat2"]
|
||||
|
||||
|
||||
def test_tuple_format_implicit():
|
||||
"""Test tuple format with implicit GoldParse creation"""
|
||||
"""Test tuple format"""
|
||||
|
||||
train_data = [
|
||||
("Uber blew through $1 million a week", {"entities": [(0, 4, "ORG")]}),
|
||||
|
@ -428,7 +432,7 @@ def test_tuple_format_implicit():
|
|||
|
||||
@pytest.mark.xfail # TODO
|
||||
def test_tuple_format_implicit_invalid():
|
||||
"""Test that an error is thrown for an implicit invalid GoldParse field"""
|
||||
"""Test that an error is thrown for an implicit invalid field"""
|
||||
|
||||
train_data = [
|
||||
("Uber blew through $1 million a week", {"frumble": [(0, 4, "ORG")]}),
|
||||
|
|
|
@ -169,6 +169,23 @@ def test_Example_from_dict_with_entities(annots):
|
|||
assert example.reference[5].ent_type_ == "LOC"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"annots",
|
||||
[
|
||||
{
|
||||
"words": ["I", "like", "New", "York", "and", "Berlin", "."],
|
||||
"entities": [(0, 4, "LOC"), (21, 27, "LOC")], # not aligned to token boundaries
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_Example_from_dict_with_entities_invalid(annots):
|
||||
vocab = Vocab()
|
||||
predicted = Doc(vocab, words=annots["words"])
|
||||
example = Example.from_dict(predicted, annots)
|
||||
# TODO: shouldn't this throw some sort of warning ?
|
||||
assert len(list(example.reference.ents)) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"annots",
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue
Block a user