Fix conversion of NER data

This commit is contained in:
Matthew Honnibal 2020-06-23 23:58:27 +02:00
parent b82431207d
commit 78e9e15e9e

View File

@ -3,7 +3,8 @@ import srsly
from .. import util
from ..errors import Warnings
from ..tokens import Doc
from .iob_utils import biluo_tags_from_offsets
from .iob_utils import biluo_tags_from_offsets, tags_to_entities
import json
def merge_sents(sents):
@ -97,6 +98,7 @@ def json_to_annotations(doc):
spaces = []
ids = []
tags = []
ner_tags = []
pos = []
morphs = []
lemmas = []
@ -110,21 +112,22 @@ def json_to_annotations(doc):
words.append(token["orth"])
spaces.append(token.get("space", True))
ids.append(token.get('id', sent_start_i + i))
if "tag" in token:
tags.append(token["tag"])
if "pos" in token:
pos.append(token["pos"])
if "morph" in token:
morphs.append(token["morph"])
if "lemma" in token:
lemmas.append(token["lemma"])
tags.append(token.get("tag", None))
pos.append(token.get("pos", None))
morphs.append(token.get("morph", None))
lemmas.append(token.get("lemma", None))
if "head" in token:
heads.append(token["head"] + sent_start_i + i)
else:
heads.append(None)
if "dep" in token:
labels.append(token["dep"])
# Ensure ROOT label is case-insensitive
if labels[-1].lower() == "root":
labels[-1] = "ROOT"
else:
labels.append(None)
ner_tags.append(token.get("ner", None))
if i == 0:
sent_starts.append(1)
else:
@ -142,31 +145,25 @@ def json_to_annotations(doc):
brackets=brackets
)
# avoid including dummy values that look like gold info was present
if tags:
if any(tags):
example["token_annotation"]["tags"] = tags
if pos:
if any(pos):
example["token_annotation"]["pos"] = pos
if morphs:
if any(morphs):
example["token_annotation"]["morphs"] = morphs
if lemmas:
if any(lemmas):
example["token_annotation"]["lemmas"] = lemmas
if heads:
if any(head is not None for head in heads):
example["token_annotation"]["heads"] = heads
if labels:
if any(labels):
example["token_annotation"]["deps"] = labels
if pos:
example["token_annotation"]["pos"] = pos
cats = {}
for cat in paragraph.get("cats", {}):
cats[cat["label"]] = cat["value"]
entities = []
for start, end, label in paragraph.get("entities", {}):
ent_tuple = (start, end, label)
entities.append(ent_tuple)
example["doc_annotation"] = dict(
cats=cats,
entities=entities,
entities=ner_tags,
links=paragraph.get("links", []) # TODO: fix/test
)
yield example