mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
training loop in proper pipe format
This commit is contained in:
parent
0486ccabfd
commit
7de1ee69b8
|
@ -126,7 +126,7 @@ if __name__ == "__main__":
|
|||
id_to_descr=id_to_descr,
|
||||
doc_cutoff=DOC_CHAR_CUTOFF,
|
||||
dev=False,
|
||||
limit=10,
|
||||
limit=100,
|
||||
to_print=False)
|
||||
|
||||
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
|
||||
|
@ -137,6 +137,8 @@ if __name__ == "__main__":
|
|||
nlp.begin_training()
|
||||
|
||||
for itn in range(EPOCHS):
|
||||
print()
|
||||
print("EPOCH", itn)
|
||||
random.shuffle(train_data)
|
||||
losses = {}
|
||||
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
|
||||
|
@ -150,15 +152,6 @@ if __name__ == "__main__":
|
|||
)
|
||||
print("Losses", losses)
|
||||
|
||||
### BELOW CODE IS DEPRECATED ###
|
||||
|
||||
# STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
|
||||
if run_el_training:
|
||||
print("STEP 6: training", datetime.datetime.now())
|
||||
trainer = EL_Model(kb=my_kb, nlp=nlp)
|
||||
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500)
|
||||
print()
|
||||
|
||||
# STEP 7: apply the EL algorithm on the dev dataset (TODO: overlaps with code from run_el_training ?)
|
||||
if apply_to_dev:
|
||||
run_el.run_el_dev(kb=my_kb, nlp=nlp, training_dir=TRAINING_DIR, limit=2000)
|
||||
|
|
|
@ -1125,51 +1125,59 @@ class EntityLinker(Pipe):
|
|||
docs = [docs]
|
||||
golds = [golds]
|
||||
|
||||
article_docs = list()
|
||||
sentence_docs = list()
|
||||
entity_encodings = list()
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
print("doc", doc)
|
||||
for entity in gold.links:
|
||||
start, end, gold_kb = entity
|
||||
print("entity", entity)
|
||||
mention = doc[start:end].text
|
||||
print("mention", mention)
|
||||
candidates = self.kb.get_candidates(mention)
|
||||
mention = doc[start:end]
|
||||
sentence = mention.sent
|
||||
|
||||
candidates = self.kb.get_candidates(mention.text)
|
||||
for c in candidates:
|
||||
prior_prob = c.prior_prob
|
||||
kb_id = c.entity_
|
||||
print("candidate", kb_id, prior_prob)
|
||||
# TODO: currently only training on the positive instances
|
||||
if kb_id == gold_kb:
|
||||
prior_prob = c.prior_prob
|
||||
entity_encoding = c.entity_vector
|
||||
print()
|
||||
|
||||
print()
|
||||
entity_encodings.append(entity_encoding)
|
||||
article_docs.append(doc)
|
||||
sentence_docs.append(sentence.as_doc())
|
||||
|
||||
# entity_encodings = None #TODO
|
||||
# doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
|
||||
# sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
|
||||
#
|
||||
# concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
|
||||
# range(len(article_docs))]
|
||||
# mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
||||
#
|
||||
# loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
|
||||
#
|
||||
# mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
|
||||
#
|
||||
# # gradient : concat (doc+sent) vs. desc
|
||||
# sent_start = self.article_encoder.nO
|
||||
# sent_gradients = list()
|
||||
# doc_gradients = list()
|
||||
# for x in mention_gradient:
|
||||
# doc_gradients.append(list(x[0:sent_start]))
|
||||
# sent_gradients.append(list(x[sent_start:]))
|
||||
#
|
||||
# bp_doc(doc_gradients, sgd=self.sgd_article)
|
||||
# bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
#
|
||||
# if losses is not None:
|
||||
# losses.setdefault(self.name, 0.0)
|
||||
# losses[self.name] += loss
|
||||
# return loss
|
||||
return None
|
||||
if len(entity_encodings) > 0:
|
||||
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
|
||||
sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
|
||||
|
||||
concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
|
||||
range(len(article_docs))]
|
||||
mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
|
||||
|
||||
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
|
||||
|
||||
loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
|
||||
|
||||
mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
|
||||
|
||||
# gradient : concat (doc+sent) vs. desc
|
||||
sent_start = self.article_encoder.nO
|
||||
sent_gradients = list()
|
||||
doc_gradients = list()
|
||||
for x in mention_gradient:
|
||||
doc_gradients.append(list(x[0:sent_start]))
|
||||
sent_gradients.append(list(x[sent_start:]))
|
||||
|
||||
bp_doc(doc_gradients, sgd=self.sgd_article)
|
||||
bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += loss
|
||||
return loss
|
||||
|
||||
return 0
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
loss, gradients = get_cossim_loss(scores, golds)
|
||||
|
|
Loading…
Reference in New Issue
Block a user