mirror of https://github.com/explosion/spaCy.git (synced 2024-12-25 17:36:30 +03:00)
have gold.links correspond exactly to doc.ents
This commit is contained in:
parent e1213eaf6a
commit 21176517a7
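
In short: `GoldParse.links` moves from flat `(start_char, end_char, kb_id)` keys to `(start_char, end_char)` keys holding per-candidate dicts, so the keys line up one-to-one with `doc.ents`. A minimal before/after sketch (the offsets and Q-IDs here are made-up illustrations, not from the diff):

    # old format: one flat entry per (mention, candidate) pair
    links_old = {
        (0, 13, "Q42"): 1.0,   # positive example
        (0, 13, "Q666"): 0.0,  # negative example
    }

    # new format: one entry per entity span, mapping kb_id -> value
    links_new = {
        (0, 13): {"Q42": 1.0, "Q666": 0.0},
    }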
@@ -397,33 +397,43 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                         current_doc = None
                     else:
                         sent = found_ent.sent.as_doc()
-                        # currently feeding the gold data one entity per sentence at a time
 
                         gold_start = int(start) - found_ent.sent.start_char
                         gold_end = int(end) - found_ent.sent.start_char
 
-                        # add both pos and neg examples (in random order)
-                        # this will exclude examples not in the KB
-                        if kb:
-                            gold_entities = {}
-                            candidates = kb.get_candidates(alias)
-                            candidate_ids = [c.entity_ for c in candidates]
-                            random.shuffle(candidate_ids)
-                            for kb_id in candidate_ids:
-                                entry = (gold_start, gold_end, kb_id)
-                                if kb_id != wd_id:
-                                    gold_entities[entry] = 0.0
-                                else:
-                                    gold_entities[entry] = 1.0
-                        # keep all positive examples
-                        else:
-                            entry = (gold_start, gold_end, wd_id)
-                            gold_entities = {entry: 1.0}
-                        gold = GoldParse(doc=sent, links=gold_entities)
-                        data.append((sent, gold))
-                        total_entities += 1
-                        if len(data) % 2500 == 0:
-                            print(" -read", total_entities, "entities")
+                        gold_entities = {}
+                        found_useful = False
+                        for ent in sent.ents:
+                            if ent.start_char == gold_start and ent.end_char == gold_end:
+                                # add both pos and neg examples (in random order)
+                                # this will exclude examples not in the KB
+                                if kb:
+                                    value_by_id = {}
+                                    candidates = kb.get_candidates(alias)
+                                    candidate_ids = [c.entity_ for c in candidates]
+                                    random.shuffle(candidate_ids)
+                                    for kb_id in candidate_ids:
+                                        found_useful = True
+                                        if kb_id != wd_id:
+                                            value_by_id[kb_id] = 0.0
+                                        else:
+                                            value_by_id[kb_id] = 1.0
+                                    gold_entities[(ent.start_char, ent.end_char)] = value_by_id
+                                # if no KB, keep all positive examples
+                                else:
+                                    found_useful = True
+                                    value_by_id = {wd_id: 1.0}
+                                    gold_entities[(ent.start_char, ent.end_char)] = value_by_id
+                            # currently feeding the gold data one entity per sentence at a time
+                            # setting all other entities to empty gold dictionary
+                            else:
+                                gold_entities[(ent.start_char, ent.end_char)] = {}
+                        if found_useful:
+                            gold = GoldParse(doc=sent, links=gold_entities)
+                            data.append((sent, gold))
+                            total_entities += 1
+                            if len(data) % 2500 == 0:
+                                print(" -read", total_entities, "entities")
 
     print(" -read", total_entities, "entities")
    return data
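The net effect of the rewritten loop above: every entity in the sentence gets a key in `gold_entities`, but only the span matching the gold offsets carries candidate values; the rest get empty dicts. A sketch of one possible result (offsets and IDs invented for illustration):

    # sentence with ents at (0, 13) and (20, 26); gold annotation at (0, 13)
    gold_entities = {
        (0, 13): {"Q42": 1.0, "Q666": 0.0},  # the matched gold entity
        (20, 26): {},                        # other entity: empty gold dict
    }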
@@ -322,11 +322,12 @@ def _measure_acc(data, el_pipe=None, error_analysis=False):
     for doc, gold in zip(docs, golds):
         try:
             correct_entries_per_article = dict()
-            for entity, value in gold.links.items():
-                # only evaluating on positive examples
-                if value:
-                    start, end, gold_kb = entity
-                    correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
+                for gold_kb, value in kb_dict.items():
+                    # only evaluating on positive examples
+                    if value:
+                        correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
 
             for ent in doc.ents:
                 ent_label = ent.label_
@@ -379,11 +380,12 @@ def _measure_baselines(data, kb):
     for doc, gold in zip(docs, golds):
         try:
             correct_entries_per_article = dict()
-            for entity, value in gold.links.items():
-                start, end, gold_kb = entity
-                # only evaluating on positive examples
-                if value:
-                    correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
+                for gold_kb, value in kb_dict.items():
+                    # only evaluating on positive examples
+                    if value:
+                        correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
 
             for ent in doc.ents:
                 label = ent.label_
@@ -487,7 +489,7 @@ def run_el_toy_example(nlp):
         "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
         "Douglas reminds us to always bring our towel, even in China or Brazil. "
         "The main character in Doug's novel is the man Arthur Dent, "
-        "but Douglas doesn't write about George Washington or Homer Simpson."
+        "but Dougledydoug doesn't write about George Washington or Homer Simpson."
     )
     doc = nlp(text)
     print(text)
@@ -450,10 +450,11 @@ cdef class GoldParse:
             examples of a label to have the value 0.0. Labels not in the
             dictionary are treated as missing - the gradient for those labels
             will be zero.
-        links (dict): A dict with `(start_char, end_char, kb_id)` keys,
-            representing the external ID of an entity in a knowledge base,
-            and the values being either 1.0 or 0.0, indicating positive and
-            negative examples, respectively.
+        links (dict): A dict with `(start_char, end_char)` keys,
+            and the values being dicts with kb_id:value entries,
+            representing the external IDs in a knowledge base (KB)
+            mapped to either 1.0 or 0.0, indicating positive and
+            negative examples respectively.
         RETURNS (GoldParse): The newly constructed object.
         """
         if words is None:
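A usage sketch of the documented format (the `sent` Doc, the offsets, and the Q-IDs are assumptions for illustration):

    from spacy.gold import GoldParse

    # assuming `sent` is a Doc whose first entity spans characters 0-13
    links = {(0, 13): {"Q42": 1.0, "Q666": 0.0}}
    gold = GoldParse(doc=sent, links=links)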
@@ -1076,6 +1076,7 @@ class EntityLinker(Pipe):
     DOCS: TODO
     """
     name = 'entity_linker'
+    NIL = "NIL"  # string used to refer to a non-existing link
 
     @classmethod
     def Model(cls, **cfg):
@@ -1151,27 +1152,28 @@ class EntityLinker(Pipe):
             ents_by_offset = dict()
             for ent in doc.ents:
                 ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
-            for entity, value in gold.links.items():
-                start, end, kb_id = entity
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
                 mention = doc.text[start:end]
-                entity_encoding = self.kb.get_vector(kb_id)
-                prior_prob = self.kb.get_prior_prob(kb_id, mention)
+                for kb_id, value in kb_dict.items():
+                    entity_encoding = self.kb.get_vector(kb_id)
+                    prior_prob = self.kb.get_prior_prob(kb_id, mention)
 
-                gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
-                assert gold_ent is not None
-                type_vector = [0 for i in range(len(type_to_int))]
-                if len(type_to_int) > 0:
-                    type_vector[type_to_int[gold_ent.label_]] = 1
+                    gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
+                    assert gold_ent is not None
+                    type_vector = [0 for i in range(len(type_to_int))]
+                    if len(type_to_int) > 0:
+                        type_vector[type_to_int[gold_ent.label_]] = 1
 
-                # store data
-                entity_encodings.append(entity_encoding)
-                context_docs.append(doc)
-                type_vectors.append(type_vector)
+                    # store data
+                    entity_encodings.append(entity_encoding)
+                    context_docs.append(doc)
+                    type_vectors.append(type_vector)
 
-                if self.cfg.get("prior_weight", 1) > 0:
-                    priors.append([prior_prob])
-                else:
-                    priors.append([0])
+                    if self.cfg.get("prior_weight", 1) > 0:
+                        priors.append([prior_prob])
+                    else:
+                        priors.append([0])
 
         if len(entity_encodings) > 0:
             assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)
@@ -1197,8 +1199,9 @@ class EntityLinker(Pipe):
     def get_loss(self, docs, golds, scores):
         cats = []
         for gold in golds:
-            for entity, value in gold.links.items():
-                cats.append([value])
+            for entity, kb_dict in gold.links.items():
+                for kb_id, value in kb_dict.items():
+                    cats.append([value])
 
         cats = self.model.ops.asarray(cats, dtype="float32")
         assert len(scores) == len(cats)
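Because `predict` scores one candidate per (entity, kb_id) pair, flattening the nested dicts keeps `cats` aligned one-to-one with `scores`. A rough sketch of that flattening (the gold values are invented):

    links = {(0, 13): {"Q42": 1.0, "Q666": 0.0}}
    cats = [[value] for kb_dict in links.values() for value in kb_dict.values()]
    # cats == [[1.0], [0.0]] -- one row per candidate, matching len(scores)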
@@ -1209,26 +1212,27 @@ class EntityLinker(Pipe):
         return loss, d_scores
 
     def __call__(self, doc):
-        entities, kb_ids = self.predict([doc])
-        self.set_annotations([doc], entities, kb_ids)
+        kb_ids = self.predict([doc])
+        self.set_annotations([doc], kb_ids)
         return doc
 
     def pipe(self, stream, batch_size=128, n_threads=-1):
         for docs in util.minibatch(stream, size=batch_size):
             docs = list(docs)
-            entities, kb_ids = self.predict(docs)
-            self.set_annotations(docs, entities, kb_ids)
+            kb_ids = self.predict(docs)
+            self.set_annotations(docs, kb_ids)
             yield from docs
 
     def predict(self, docs):
+        """ Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
         self.require_model()
         self.require_kb()
 
-        final_entities = []
+        entity_count = 0
         final_kb_ids = []
 
         if not docs:
-            return final_entities, final_kb_ids
+            return final_kb_ids
 
         if isinstance(docs, Doc):
             docs = [docs]
@@ -1242,12 +1246,15 @@ class EntityLinker(Pipe):
             if len(doc) > 0:
                 context_encoding = context_encodings[i]
                 for ent in doc.ents:
+                    entity_count += 1
                     type_vector = [0 for i in range(len(type_to_int))]
                     if len(type_to_int) > 0:
                         type_vector[type_to_int[ent.label_]] = 1
 
                     candidates = self.kb.get_candidates(ent.text)
-                    if candidates:
+                    if not candidates:
+                        final_kb_ids.append(self.NIL)  # no prediction possible for this entity
+                    else:
                         random.shuffle(candidates)
 
                         # this will set the prior probabilities to 0 (just like in training) if their weight is 0
@@ -1266,15 +1273,20 @@ class EntityLinker(Pipe):
                         # TODO: thresholding
                         best_index = scores.argmax()
                         best_candidate = candidates[best_index]
-                        final_entities.append(ent)
                         final_kb_ids.append(best_candidate.entity_)
 
-        return final_entities, final_kb_ids
+        assert len(final_kb_ids) == entity_count
 
-    def set_annotations(self, docs, entities, kb_ids=None):
-        for entity, kb_id in zip(entities, kb_ids):
-            for token in entity:
-                token.ent_kb_id_ = kb_id
+        return final_kb_ids
+
+    def set_annotations(self, docs, kb_ids, tensors=None):
+        i = 0
+        for doc in docs:
+            for ent in doc.ents:
+                kb_id = kb_ids[i]
+                i += 1
+                for token in ent:
+                    token.ent_kb_id_ = kb_id
 
     def to_disk(self, path, exclude=tuple(), **kwargs):
         serialize = OrderedDict()
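With `predict` now returning exactly one KB ID per entity (`NIL` included), `set_annotations` is a straight walk over `doc.ents`. A usage sketch, assuming `nlp` has this `entity_linker` component in its pipeline:

    doc = nlp("Douglas Adams wrote The Hitchhiker's Guide to the Galaxy.")
    for ent in doc.ents:
        for token in ent:
            # "NIL" marks entities for which no prediction was possible
            print(token.text, token.ent_kb_id_)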