use original gold object in get_loss function

This commit is contained in:
svlandeg 2019-07-18 13:35:10 +02:00
parent ec55d2fccd
commit e1213eaf6a
2 changed files with 18 additions and 15 deletions

View File

@ -295,11 +295,7 @@ def run_pipeline():
dev_limit = 5000
dev_data = training_set_creator.read_training(
nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=True,
limit=dev_limit,
kb=el_pipe.kb,
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
)
print("Dev testing from file on", len(dev_data), "articles")
@ -383,9 +379,11 @@ def _measure_baselines(data, kb):
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity in gold.links:
for entity, value in gold.links.items():
start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
# only evaluating on positive examples
if value:
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents:
label = ent.label_

View File

@ -1141,7 +1141,7 @@ class EntityLinker(Pipe):
context_docs = []
entity_encodings = []
cats = []
priors = []
type_vectors = []
@ -1173,12 +1173,9 @@ class EntityLinker(Pipe):
else:
priors.append([0])
cats.append([value])
if len(entity_encodings) > 0:
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)
cats = self.model.ops.asarray(cats, dtype="float32")
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
@ -1186,7 +1183,7 @@ class EntityLinker(Pipe):
for i in range(len(entity_encodings))]
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
loss, d_scores = self.get_loss(scores=pred, golds=cats, docs=docs)
loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs)
mention_gradient = bp_mention(d_scores, sgd=sgd)
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
@ -1198,9 +1195,17 @@ class EntityLinker(Pipe):
return 0
def get_loss(self, docs, golds, scores):
    """Compute the mean squared error of the predicted scores against the
    link values stored on the gold objects.

    docs: Not used by this computation; kept so the signature is unchanged.
    golds: Iterable of gold objects. Each must expose a `links` dict whose
        values are the target values, one per annotated link.
    scores: Predicted scores, one row per link value, in the same order in
        which `golds` and each `gold.links` are iterated.
    RETURNS: A `(loss, d_scores)` tuple. `d_scores` is `scores - targets`
        and `loss` is the sum of squared differences divided by the number
        of link values.
    RAISES: ValueError if the number of scores differs from the number of
        link values found on the gold objects.
    """
    # Flatten the gold link values into one single-element row per link.
    # Only the values are needed here; the link keys are not used.
    cats = [[value] for gold in golds for value in gold.links.values()]
    cats = self.model.ops.asarray(cats, dtype="float32")

    # An explicit check rather than `assert`, so that the guard is not
    # stripped when Python runs with -O.
    # NOTE(review): this presumes the caller produced exactly one score per
    # entry in `gold.links` -- confirm against the update method.
    if len(scores) != len(cats):
        raise ValueError(
            "Number of scores ({}) does not match number of gold link "
            "values ({})".format(len(scores), len(cats))
        )

    d_scores = scores - cats
    loss = (d_scores ** 2).sum()
    # Normalise exactly once, by the number of link values. Skip the
    # division when there are none, to avoid dividing by zero.
    if len(cats) > 0:
        loss = loss / len(cats)
    return loss, d_scores
def __call__(self, doc):