mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
60% acc run
This commit is contained in:
parent
268a52ead7
commit
9e88763dab
|
@ -23,7 +23,6 @@ from thinc.misc import LayerNorm as LN
|
|||
|
||||
# from spacy.cli.pretrain import get_cossim_loss
|
||||
from spacy.matcher import PhraseMatcher
|
||||
from spacy.tokens import Doc
|
||||
|
||||
""" TODO: this code needs to be implemented in pipes.pyx"""
|
||||
|
||||
|
@ -46,7 +45,7 @@ class EL_Model:
|
|||
|
||||
DROP = 0.1
|
||||
LEARN_RATE = 0.001
|
||||
EPOCHS = 10
|
||||
EPOCHS = 20
|
||||
L2 = 1e-6
|
||||
|
||||
name = "entity_linker"
|
||||
|
@ -211,9 +210,6 @@ class EL_Model:
|
|||
return acc
|
||||
|
||||
def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
|
||||
# print()
|
||||
# print("predicting article")
|
||||
|
||||
if avg:
|
||||
with self.article_encoder.use_params(self.sgd_article.averages) \
|
||||
and self.desc_encoder.use_params(self.sgd_desc.averages)\
|
||||
|
@ -228,16 +224,10 @@ class EL_Model:
|
|||
doc_encoding = self.article_encoder([article_doc])
|
||||
sent_encoding = self.sent_encoder([sent_doc])
|
||||
|
||||
# print("desc_encodings", desc_encodings)
|
||||
# print("doc_encoding", doc_encoding)
|
||||
# print("sent_encoding", sent_encoding)
|
||||
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
|
||||
# print("concat_encoding", concat_encoding)
|
||||
|
||||
cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
|
||||
# print("cont_encodings", cont_encodings)
|
||||
context_enc = np.transpose(cont_encodings)
|
||||
# print("context_enc", context_enc)
|
||||
|
||||
highest_sim = -5
|
||||
best_i = -1
|
||||
|
@ -353,11 +343,11 @@ class EL_Model:
|
|||
sents_list.append(sent)
|
||||
descs_list.append(descs[e])
|
||||
targets.append([1])
|
||||
else:
|
||||
arts_list.append(art)
|
||||
sents_list.append(sent)
|
||||
descs_list.append(descs[e])
|
||||
targets.append([-1])
|
||||
# else:
|
||||
# arts_list.append(art)
|
||||
# sents_list.append(sent)
|
||||
# descs_list.append(descs[e])
|
||||
# targets.append([-1])
|
||||
|
||||
desc_docs = self.nlp.pipe(descs_list)
|
||||
desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
|
||||
|
@ -372,18 +362,17 @@ class EL_Model:
|
|||
range(len(targets))]
|
||||
cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
||||
|
||||
# print("sent_encodings", type(sent_encodings), sent_encodings)
|
||||
# print("desc_encodings", type(desc_encodings), desc_encodings)
|
||||
# print("doc_encodings", type(doc_encodings), doc_encodings)
|
||||
# print("getting los for", len(arts_list), "entities")
|
||||
loss, cont_gradient = self.get_loss(cont_encodings, desc_encodings, targets)
|
||||
|
||||
loss, gradient = self.get_loss(cont_encodings, desc_encodings, targets)
|
||||
# loss, desc_gradient = self.get_loss(desc_encodings, cont_encodings, targets)
|
||||
# cont_gradient = cont_gradient / 2
|
||||
# desc_gradient = desc_gradient / 2
|
||||
# bp_desc(desc_gradient, sgd=self.sgd_desc)
|
||||
|
||||
# print("gradient", gradient)
|
||||
if self.PRINT_BATCH_LOSS:
|
||||
print("batch loss", loss)
|
||||
|
||||
context_gradient = bp_cont(gradient, sgd=self.sgd_cont)
|
||||
context_gradient = bp_cont(cont_gradient, sgd=self.sgd_cont)
|
||||
|
||||
# gradient : concat (doc+sent) vs. desc
|
||||
sent_start = self.ARTICLE_WIDTH
|
||||
|
@ -393,9 +382,6 @@ class EL_Model:
|
|||
doc_gradients.append(list(x[0:sent_start]))
|
||||
sent_gradients.append(list(x[sent_start:]))
|
||||
|
||||
# print("doc_gradients", doc_gradients)
|
||||
# print("sent_gradients", sent_gradients)
|
||||
|
||||
bp_doc(doc_gradients, sgd=self.sgd_article)
|
||||
bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
|
||||
|
@ -426,74 +412,75 @@ class EL_Model:
|
|||
article_id = f.replace(".txt", "")
|
||||
if cnt % 500 == 0 and to_print:
|
||||
print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
|
||||
cnt += 1
|
||||
|
||||
# parse the article text
|
||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||
text = file.read()
|
||||
article_doc = self.nlp(text)
|
||||
truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
|
||||
text_by_article[article_id] = truncated_text
|
||||
try:
|
||||
# parse the article text
|
||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||
text = file.read()
|
||||
article_doc = self.nlp(text)
|
||||
truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
|
||||
text_by_article[article_id] = truncated_text
|
||||
|
||||
# process all positive and negative entities, collect all relevant mentions in this article
|
||||
for mention, entity_pos in correct_entries[article_id].items():
|
||||
cluster = article_id + "_" + mention
|
||||
descr = id_to_descr.get(entity_pos)
|
||||
entities = set()
|
||||
if descr:
|
||||
entity = "E_" + str(next_entity_nr) + "_" + cluster
|
||||
next_entity_nr += 1
|
||||
gold_by_entity[entity] = 1
|
||||
desc_by_entity[entity] = descr
|
||||
entities.add(entity)
|
||||
# process all positive and negative entities, collect all relevant mentions in this article
|
||||
for mention, entity_pos in correct_entries[article_id].items():
|
||||
cluster = article_id + "_" + mention
|
||||
descr = id_to_descr.get(entity_pos)
|
||||
entities = set()
|
||||
if descr:
|
||||
entity = "E_" + str(next_entity_nr) + "_" + cluster
|
||||
next_entity_nr += 1
|
||||
gold_by_entity[entity] = 1
|
||||
desc_by_entity[entity] = descr
|
||||
entities.add(entity)
|
||||
|
||||
entity_negs = incorrect_entries[article_id][mention]
|
||||
for entity_neg in entity_negs:
|
||||
descr = id_to_descr.get(entity_neg)
|
||||
if descr:
|
||||
entity = "E_" + str(next_entity_nr) + "_" + cluster
|
||||
next_entity_nr += 1
|
||||
gold_by_entity[entity] = 0
|
||||
desc_by_entity[entity] = descr
|
||||
entities.add(entity)
|
||||
entity_negs = incorrect_entries[article_id][mention]
|
||||
for entity_neg in entity_negs:
|
||||
descr = id_to_descr.get(entity_neg)
|
||||
if descr:
|
||||
entity = "E_" + str(next_entity_nr) + "_" + cluster
|
||||
next_entity_nr += 1
|
||||
gold_by_entity[entity] = 0
|
||||
desc_by_entity[entity] = descr
|
||||
entities.add(entity)
|
||||
|
||||
found_matches = 0
|
||||
if len(entities) > 1:
|
||||
entities_by_cluster[cluster] = entities
|
||||
found_matches = 0
|
||||
if len(entities) > 1:
|
||||
entities_by_cluster[cluster] = entities
|
||||
|
||||
# find all matches in the doc for the mentions
|
||||
# TODO: fix this - doesn't look like all entities are found
|
||||
matcher = PhraseMatcher(self.nlp.vocab)
|
||||
patterns = list(self.nlp.tokenizer.pipe([mention]))
|
||||
# find all matches in the doc for the mentions
|
||||
# TODO: fix this - doesn't look like all entities are found
|
||||
matcher = PhraseMatcher(self.nlp.vocab)
|
||||
patterns = list(self.nlp.tokenizer.pipe([mention]))
|
||||
|
||||
matcher.add("TerminologyList", None, *patterns)
|
||||
matches = matcher(article_doc)
|
||||
matcher.add("TerminologyList", None, *patterns)
|
||||
matches = matcher(article_doc)
|
||||
|
||||
# store sentences
|
||||
for match_id, start, end in matches:
|
||||
span = article_doc[start:end]
|
||||
if mention == span.text:
|
||||
found_matches += 1
|
||||
sent_text = span.sent.text
|
||||
sent_nr = sentence_by_text.get(sent_text, None)
|
||||
if sent_nr is None:
|
||||
sent_nr = "S_" + str(next_sent_nr) + article_id
|
||||
next_sent_nr += 1
|
||||
text_by_sentence[sent_nr] = sent_text
|
||||
sentence_by_text[sent_text] = sent_nr
|
||||
article_by_cluster[cluster] = article_id
|
||||
sentence_by_cluster[cluster] = sent_nr
|
||||
|
||||
# store sentences
|
||||
for match_id, start, end in matches:
|
||||
found_matches += 1
|
||||
span = article_doc[start:end]
|
||||
assert mention == span.text
|
||||
sent_text = span.sent.text
|
||||
sent_nr = sentence_by_text.get(sent_text, None)
|
||||
if sent_nr is None:
|
||||
sent_nr = "S_" + str(next_sent_nr) + article_id
|
||||
next_sent_nr += 1
|
||||
text_by_sentence[sent_nr] = sent_text
|
||||
sentence_by_text[sent_text] = sent_nr
|
||||
article_by_cluster[cluster] = article_id
|
||||
sentence_by_cluster[cluster] = sent_nr
|
||||
|
||||
if found_matches == 0:
|
||||
# TODO print("Could not find neg instances or sentence matches for", mention, "in", article_id)
|
||||
entities_by_cluster.pop(cluster, None)
|
||||
article_by_cluster.pop(cluster, None)
|
||||
sentence_by_cluster.pop(cluster, None)
|
||||
for entity in entities:
|
||||
gold_by_entity.pop(entity, None)
|
||||
desc_by_entity.pop(entity, None)
|
||||
|
||||
if found_matches == 0:
|
||||
# print("Could not find neg instances or sentence matches for", mention, "in", article_id)
|
||||
entities_by_cluster.pop(cluster, None)
|
||||
article_by_cluster.pop(cluster, None)
|
||||
sentence_by_cluster.pop(cluster, None)
|
||||
for entity in entities:
|
||||
gold_by_entity.pop(entity, None)
|
||||
desc_by_entity.pop(entity, None)
|
||||
cnt += 1
|
||||
except:
|
||||
print("Problem parsing article", article_id)
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
|
|
|
@ -111,7 +111,7 @@ if __name__ == "__main__":
|
|||
print("STEP 6: training", datetime.datetime.now())
|
||||
my_nlp = spacy.load('en_core_web_md')
|
||||
trainer = EL_Model(kb=my_kb, nlp=my_nlp)
|
||||
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=100)
|
||||
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500)
|
||||
print()
|
||||
|
||||
# STEP 7: apply the EL algorithm on the dev dataset
|
||||
|
@ -120,7 +120,6 @@ if __name__ == "__main__":
|
|||
run_el.run_el_dev(kb=my_kb, nlp=my_nlp, training_dir=TRAINING_DIR, limit=2000)
|
||||
print()
|
||||
|
||||
|
||||
# TODO coreference resolution
|
||||
# add_coref()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user