Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-04 11:53:09 +03:00)

Commit e1213eaf6a: use original gold object in get_loss function
Parent: ec55d2fccd
In run_pipeline(), the dev-set read is collapsed into a single call and no longer receives the pipeline's knowledge base (kb=None instead of kb=el_pipe.kb):

```diff
@@ -295,11 +295,7 @@ def run_pipeline():
     dev_limit = 5000
     dev_data = training_set_creator.read_training(
-        nlp=nlp_2,
-        training_dir=TRAINING_DIR,
-        dev=True,
-        limit=dev_limit,
-        kb=el_pipe.kb,
+        nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
     )

     print("Dev testing from file on", len(dev_data), "articles")
```
In _measure_baselines(), gold.links is now iterated as a dict so that only positive examples (truthy values) count as correct entries:

```diff
@@ -383,8 +379,10 @@ def _measure_baselines(data, kb):
     for doc, gold in zip(docs, golds):
         try:
             correct_entries_per_article = dict()
-            for entity in gold.links:
+            for entity, value in gold.links.items():
                 start, end, gold_kb = entity
-                correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+                # only evaluating on positive examples
+                if value:
+                    correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb

             for ent in doc.ents:
```
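For illustration, a minimal runnable sketch of the new filtering, using a hand-built links dict in the shape the loop expects: (start, end, kb_id) keys mapped to a 0/1 value. The concrete entries are invented for the example, not taken from the commit:

```python
# Sketch: gold.links maps (start_char, end_char, kb_id) -> value,
# where a truthy value marks a positive (correct) annotation.
links = {
    (0, 6, "Q5"): 1.0,     # positive example: kept
    (0, 6, "Q312"): 0.0,   # negative example: skipped
    (10, 16, "Q64"): 1.0,  # positive example: kept
}

correct_entries_per_article = dict()
for entity, value in links.items():
    start, end, gold_kb = entity
    # only evaluating on positive examples
    if value:
        correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb

print(correct_entries_per_article)  # {'0-6': 'Q5', '10-16': 'Q64'}
```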
In EntityLinker.update(), the local cats accumulator is removed:

```diff
@@ -1141,7 +1141,7 @@ class EntityLinker(Pipe):

         context_docs = []
         entity_encodings = []
-        cats = []
         priors = []
         type_vectors = []

```
The per-candidate cats.append and the array conversion are removed as well, and cats drops out of the length assertion:

```diff
@@ -1173,12 +1173,9 @@ class EntityLinker(Pipe):
             else:
                 priors.append([0])

-            cats.append([value])
-
        if len(entity_encodings) > 0:
-            assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
+            assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)

-            cats = self.model.ops.asarray(cats, dtype="float32")
             entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")

             context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
```
update() now passes the original gold objects to get_loss() rather than the flattened cats array:

```diff
@@ -1186,7 +1183,7 @@ class EntityLinker(Pipe):
                                  for i in range(len(entity_encodings))]
             pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)

-            loss, d_scores = self.get_loss(scores=pred, golds=cats, docs=docs)
+            loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs)
             mention_gradient = bp_mention(d_scores, sgd=sgd)

             context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
```
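This is the change the commit title refers to: update() forwards the original gold objects, and the flattening moves into get_loss() (next hunk). For that split to be correct, gold.links must be traversed in the same order in both places; a small runnable sketch of why that holds, assuming insertion-ordered dicts (guaranteed from Python 3.7, true in CPython 3.6):

```python
# Flattening the same links dict twice yields the same order, because
# Python dicts iterate in insertion order. This is what lets update()
# build entity encodings and get_loss() build targets independently.
links = {(0, 6, "Q5"): 1.0, (0, 6, "Q312"): 0.0, (10, 16, "Q64"): 1.0}

order_in_update = [entity for entity, value in links.items()]
order_in_get_loss = [entity for entity, value in links.items()]
assert order_in_update == order_in_get_loss  # stable iteration order
```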
Finally, get_loss() builds the target array from gold.links itself and normalizes the loss by the number of candidates instead of the number of gold objects:

```diff
@@ -1198,9 +1195,17 @@ class EntityLinker(Pipe):
             return 0

     def get_loss(self, docs, golds, scores):
-        d_scores = (scores - golds)
+        cats = []
+        for gold in golds:
+            for entity, value in gold.links.items():
+                cats.append([value])
+
+        cats = self.model.ops.asarray(cats, dtype="float32")
+        assert len(scores) == len(cats)
+
+        d_scores = (scores - cats)
         loss = (d_scores ** 2).sum()
-        loss = loss / len(golds)
+        loss = loss / len(cats)
         return loss, d_scores

     def __call__(self, doc):
```
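To make the new get_loss() concrete, here is a self-contained numpy sketch. self.model.ops.asarray is replaced with numpy.asarray, and MockGold is an invented stand-in exposing only the .links attribute the method uses; the arithmetic mirrors the hunk above:

```python
import numpy


class MockGold:
    """Invented stand-in for the gold object; only .links is used here."""

    def __init__(self, links):
        self.links = links


def get_loss(golds, scores):
    # Flatten the link annotations into one target row per candidate.
    cats = []
    for gold in golds:
        for entity, value in gold.links.items():
            cats.append([value])

    cats = numpy.asarray(cats, dtype="float32")
    assert len(scores) == len(cats)

    # Summed squared error, averaged over candidates; d_scores stays the
    # raw difference (the commit does not rescale the gradient either).
    d_scores = scores - cats
    loss = (d_scores ** 2).sum()
    loss = loss / len(cats)
    return loss, d_scores


golds = [
    MockGold({(0, 6, "Q5"): 1.0, (0, 6, "Q312"): 0.0}),
    MockGold({(10, 16, "Q64"): 1.0}),
]
scores = numpy.asarray([[0.9], [0.2], [0.7]], dtype="float32")
loss, d_scores = get_loss(golds, scores)
print(loss)  # ~0.0467: (0.01 + 0.04 + 0.09) / 3
```

Note that the divisor also changed from len(golds), the number of gold objects (i.e. documents), to len(cats), the number of candidates, so the loss is now a per-candidate mean.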