mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-30 01:43:21 +03:00
code cleanup
This commit is contained in:
parent
cdc589d344
commit
a63d15a142
|
@ -14,7 +14,6 @@ from thinc.neural.util import to_categorical
|
||||||
from thinc.neural.util import get_array_module
|
from thinc.neural.util import get_array_module
|
||||||
|
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
from ..cli.pretrain import get_cossim_loss
|
|
||||||
from .functions import merge_subtokens
|
from .functions import merge_subtokens
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..syntax.nn_parser cimport Parser
|
from ..syntax.nn_parser cimport Parser
|
||||||
|
@ -1164,7 +1163,6 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
candidates = self.kb.get_candidates(mention)
|
candidates = self.kb.get_candidates(mention)
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
nr_neg = 0
|
|
||||||
for c in candidates:
|
for c in candidates:
|
||||||
kb_id = c.entity_
|
kb_id = c.entity_
|
||||||
entity_encoding = c.entity_vector
|
entity_encoding = c.entity_vector
|
||||||
|
@ -1180,21 +1178,20 @@ class EntityLinker(Pipe):
|
||||||
if kb_id == gold_kb:
|
if kb_id == gold_kb:
|
||||||
cats.append([1])
|
cats.append([1])
|
||||||
else:
|
else:
|
||||||
nr_neg += 1
|
|
||||||
cats.append([0])
|
cats.append([0])
|
||||||
|
|
||||||
if len(entity_encodings) > 0:
|
if len(entity_encodings) > 0:
|
||||||
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
||||||
|
|
||||||
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
|
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||||
|
|
||||||
|
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
|
||||||
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
|
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
|
||||||
for i in range(len(entity_encodings))]
|
for i in range(len(entity_encodings))]
|
||||||
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
|
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
|
||||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
|
||||||
|
|
||||||
loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
|
loss, d_scores = self.get_loss(scores=pred, golds=cats, docs=docs)
|
||||||
mention_gradient = bp_mention(d_scores, sgd=sgd)
|
mention_gradient = bp_mention(d_scores, sgd=sgd)
|
||||||
|
|
||||||
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
|
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
|
||||||
|
@ -1205,18 +1202,12 @@ class EntityLinker(Pipe):
|
||||||
return loss
|
return loss
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def get_loss(self, docs, golds, prediction):
|
def get_loss(self, docs, golds, scores):
|
||||||
d_scores = (prediction - golds)
|
d_scores = (scores - golds)
|
||||||
loss = (d_scores ** 2).sum()
|
loss = (d_scores ** 2).sum()
|
||||||
loss = loss / len(golds)
|
loss = loss / len(golds)
|
||||||
return loss, d_scores
|
return loss, d_scores
|
||||||
|
|
||||||
def get_loss_old(self, docs, golds, scores):
|
|
||||||
# this loss function assumes we're only using positive examples
|
|
||||||
loss, gradients = get_cossim_loss(yh=scores, y=golds)
|
|
||||||
loss = loss / len(golds)
|
|
||||||
return loss, gradients
|
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
entities, kb_ids = self.predict([doc])
|
entities, kb_ids = self.predict([doc])
|
||||||
self.set_annotations([doc], entities, kb_ids)
|
self.set_annotations([doc], entities, kb_ids)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user