mirror of https://github.com/explosion/spaCy.git
training loop in proper pipe format

parent 0486ccabfd
commit 7de1ee69b8
@@ -126,7 +126,7 @@ if __name__ == "__main__":
         id_to_descr=id_to_descr,
         doc_cutoff=DOC_CHAR_CUTOFF,
         dev=False,
-        limit=10,
+        limit=100,
         to_print=False)

     el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
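For orientation, the pipe created above still has to be attached to the pipeline before `nlp.begin_training()` is called further down. The `add_pipe` line below is not part of this diff; it is only a minimal sketch of the usual spaCy 2.x wiring, assuming the script adds the component at the end of the pipeline.

    el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
    nlp.add_pipe(el_pipe, last=True)  # assumed wiring, not shown in this hunk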
@@ -137,6 +137,8 @@ if __name__ == "__main__":
     nlp.begin_training()

     for itn in range(EPOCHS):
+        print()
+        print("EPOCH", itn)
         random.shuffle(train_data)
         losses = {}
         batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
@@ -150,15 +152,6 @@ if __name__ == "__main__":
             )
         print("Losses", losses)

-    ### BELOW CODE IS DEPRECATED ###
-
-    # STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
-    if run_el_training:
-        print("STEP 6: training", datetime.datetime.now())
-        trainer = EL_Model(kb=my_kb, nlp=nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500)
-        print()
-
     # STEP 7: apply the EL algorithm on the dev dataset (TODO: overlaps with code from run_el_training ?)
     if apply_to_dev:
         run_el.run_el_dev(kb=my_kb, nlp=nlp, training_dir=TRAINING_DIR, limit=2000)
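Read together, the hunks above turn the script into a conventional spaCy 2.x training loop: shuffle, minibatch with a compounding batch size, call `nlp.update`, and report the accumulated losses per epoch. The sketch below fills in the `nlp.update` call whose closing parenthesis is the only part visible in the last hunk, so the exact arguments (the dropout value and the shape of `train_data`) are assumptions rather than what the script actually passes.

    import random
    from spacy.util import minibatch, compounding

    for itn in range(EPOCHS):
        print()
        print("EPOCH", itn)
        random.shuffle(train_data)
        losses = {}
        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            docs, golds = zip(*batch)  # assumption: train_data holds (Doc, GoldParse) pairs
            nlp.update(
                docs,
                golds,
                drop=0.1,      # assumed dropout value
                losses=losses,
            )
        print("Losses", losses)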
@@ -1125,51 +1125,59 @@ class EntityLinker(Pipe):
             docs = [docs]
             golds = [golds]

+        article_docs = list()
+        sentence_docs = list()
+        entity_encodings = list()
+
         for doc, gold in zip(docs, golds):
-            print("doc", doc)
             for entity in gold.links:
                 start, end, gold_kb = entity
-                print("entity", entity)
-                mention = doc[start:end].text
-                print("mention", mention)
-                candidates = self.kb.get_candidates(mention)
+                mention = doc[start:end]
+                sentence = mention.sent
+
+                candidates = self.kb.get_candidates(mention.text)
                 for c in candidates:
-                    prior_prob = c.prior_prob
                     kb_id = c.entity_
-                    print("candidate", kb_id, prior_prob)
-                    entity_encoding = c.entity_vector
-                    print()
+                    # TODO: currently only training on the positive instances
+                    if kb_id == gold_kb:
+                        prior_prob = c.prior_prob
+                        entity_encoding = c.entity_vector

-            print()
+                        entity_encodings.append(entity_encoding)
+                        article_docs.append(doc)
+                        sentence_docs.append(sentence.as_doc())

-        # entity_encodings = None #TODO
-        # doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
-        # sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
-        #
-        # concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
-        # range(len(article_docs))]
-        # mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
-        #
-        # loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
-        #
-        # mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
-        #
-        # # gradient : concat (doc+sent) vs. desc
-        # sent_start = self.article_encoder.nO
-        # sent_gradients = list()
-        # doc_gradients = list()
-        # for x in mention_gradient:
-        # doc_gradients.append(list(x[0:sent_start]))
-        # sent_gradients.append(list(x[sent_start:]))
-        #
-        # bp_doc(doc_gradients, sgd=self.sgd_article)
-        # bp_sent(sent_gradients, sgd=self.sgd_sent)
-        #
-        # if losses is not None:
-        # losses.setdefault(self.name, 0.0)
-        # losses[self.name] += loss
-        # return loss
-        return None
+        if len(entity_encodings) > 0:
+            doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
+            sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
+
+            concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
+                                range(len(article_docs))]
+            mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
+
+            entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
+
+            loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
+
+            mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
+
+            # gradient : concat (doc+sent) vs. desc
+            sent_start = self.article_encoder.nO
+            sent_gradients = list()
+            doc_gradients = list()
+            for x in mention_gradient:
+                doc_gradients.append(list(x[0:sent_start]))
+                sent_gradients.append(list(x[sent_start:]))
+
+            bp_doc(doc_gradients, sgd=self.sgd_article)
+            bp_sent(sent_gradients, sgd=self.sgd_sent)
+
+            if losses is not None:
+                losses.setdefault(self.name, 0.0)
+                losses[self.name] += loss
+            return loss
+
+        return 0

     def get_loss(self, docs, golds, scores):
         loss, gradients = get_cossim_loss(scores, golds)
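The `get_loss` override shown in the last context lines delegates to `get_cossim_loss`, whose body lies outside this hunk. As a reference point only, a cosine-distance loss over row vectors and its gradient can be sketched as follows; this is an illustrative stand-in (summed 1 - cosine per row, with assumed variable names), not the implementation in `pipes.pyx`.

    import numpy as np

    def cossim_loss_sketch(yh, y, eps=1e-8):
        # Illustrative cosine-distance loss between predicted rows yh and gold rows y.
        # Returns the summed (1 - cosine) loss and its gradient with respect to yh.
        yh = np.asarray(yh, dtype=np.float32)
        y = np.asarray(y, dtype=np.float32)
        yh_norm = np.linalg.norm(yh, axis=1, keepdims=True) + eps
        y_norm = np.linalg.norm(y, axis=1, keepdims=True) + eps
        cos = np.sum(yh * y, axis=1, keepdims=True) / (yh_norm * y_norm)
        loss = float(np.sum(1.0 - cos))
        # d(1 - cos)/d yh = -(y / (|yh||y|) - cos * yh / |yh|^2)
        d_yh = -(y / (yh_norm * y_norm) - cos * yh / (yh_norm ** 2))
        return loss, d_yh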