mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
use original gold object in get_loss function
This commit is contained in:
parent
ec55d2fccd
commit
e1213eaf6a
|
@ -295,11 +295,7 @@ def run_pipeline():
|
|||
|
||||
dev_limit = 5000
|
||||
dev_data = training_set_creator.read_training(
|
||||
nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit,
|
||||
kb=el_pipe.kb,
|
||||
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
|
||||
)
|
||||
|
||||
print("Dev testing from file on", len(dev_data), "articles")
|
||||
|
@ -383,9 +379,11 @@ def _measure_baselines(data, kb):
|
|||
for doc, gold in zip(docs, golds):
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity in gold.links:
|
||||
for entity, value in gold.links.items():
|
||||
start, end, gold_kb = entity
|
||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||
# only evaluating on positive examples
|
||||
if value:
|
||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
label = ent.label_
|
||||
|
|
|
@ -1141,7 +1141,7 @@ class EntityLinker(Pipe):
|
|||
|
||||
context_docs = []
|
||||
entity_encodings = []
|
||||
cats = []
|
||||
|
||||
priors = []
|
||||
type_vectors = []
|
||||
|
||||
|
@ -1173,12 +1173,9 @@ class EntityLinker(Pipe):
|
|||
else:
|
||||
priors.append([0])
|
||||
|
||||
cats.append([value])
|
||||
|
||||
if len(entity_encodings) > 0:
|
||||
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
||||
assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)
|
||||
|
||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||
|
||||
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
|
||||
|
@ -1186,7 +1183,7 @@ class EntityLinker(Pipe):
|
|||
for i in range(len(entity_encodings))]
|
||||
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
|
||||
|
||||
loss, d_scores = self.get_loss(scores=pred, golds=cats, docs=docs)
|
||||
loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs)
|
||||
mention_gradient = bp_mention(d_scores, sgd=sgd)
|
||||
|
||||
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
|
||||
|
@ -1198,9 +1195,17 @@ class EntityLinker(Pipe):
|
|||
return 0
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
d_scores = (scores - golds)
|
||||
cats = []
|
||||
for gold in golds:
|
||||
for entity, value in gold.links.items():
|
||||
cats.append([value])
|
||||
|
||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||
assert len(scores) == len(cats)
|
||||
|
||||
d_scores = (scores - cats)
|
||||
loss = (d_scores ** 2).sum()
|
||||
loss = loss / len(golds)
|
||||
loss = loss / len(cats)
|
||||
return loss, d_scores
|
||||
|
||||
def __call__(self, doc):
|
||||
|
|
Loading…
Reference in New Issue
Block a user