60% acc run

commit 9e88763dab (parent 268a52ead7)
@@ -23,7 +23,6 @@ from thinc.misc import LayerNorm as LN
 
 # from spacy.cli.pretrain import get_cossim_loss
 from spacy.matcher import PhraseMatcher
 from spacy.tokens import Doc
 
 """ TODO: this code needs to be implemented in pipes.pyx"""
 
@@ -46,7 +45,7 @@ class EL_Model:
 
     DROP = 0.1
     LEARN_RATE = 0.001
-    EPOCHS = 10
+    EPOCHS = 20
     L2 = 1e-6
 
     name = "entity_linker"
@@ -211,9 +210,6 @@ class EL_Model:
         return acc
 
     def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
-        # print()
-        # print("predicting article")
-
         if avg:
             with self.article_encoder.use_params(self.sgd_article.averages) \
                  and self.desc_encoder.use_params(self.sgd_desc.averages)\
@@ -228,16 +224,10 @@ class EL_Model:
             doc_encoding = self.article_encoder([article_doc])
             sent_encoding = self.sent_encoder([sent_doc])
 
-        # print("desc_encodings", desc_encodings)
-        # print("doc_encoding", doc_encoding)
-        # print("sent_encoding", sent_encoding)
         concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
-        # print("concat_encoding", concat_encoding)
 
        cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
-        # print("cont_encodings", cont_encodings)
         context_enc = np.transpose(cont_encodings)
-        # print("context_enc", context_enc)
 
         highest_sim = -5
         best_i = -1
@@ -353,11 +343,11 @@ class EL_Model:
                     sents_list.append(sent)
                     descs_list.append(descs[e])
                     targets.append([1])
-                else:
-                    arts_list.append(art)
-                    sents_list.append(sent)
-                    descs_list.append(descs[e])
-                    targets.append([-1])
+                # else:
+                #    arts_list.append(art)
+                #    sents_list.append(sent)
+                #    descs_list.append(descs[e])
+                #    targets.append([-1])
 
         desc_docs = self.nlp.pipe(descs_list)
         desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
@@ -372,18 +362,17 @@ class EL_Model:
                             range(len(targets))]
         cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
 
-        # print("sent_encodings", type(sent_encodings), sent_encodings)
-        # print("desc_encodings", type(desc_encodings), desc_encodings)
-        # print("doc_encodings", type(doc_encodings), doc_encodings)
-        # print("getting los for", len(arts_list), "entities")
+        loss, cont_gradient = self.get_loss(cont_encodings, desc_encodings, targets)
 
-        loss, gradient = self.get_loss(cont_encodings, desc_encodings, targets)
+        # loss, desc_gradient = self.get_loss(desc_encodings, cont_encodings, targets)
+        # cont_gradient = cont_gradient / 2
+        # desc_gradient = desc_gradient / 2
+        # bp_desc(desc_gradient, sgd=self.sgd_desc)
 
-        # print("gradient", gradient)
         if self.PRINT_BATCH_LOSS:
             print("batch loss", loss)
 
-        context_gradient = bp_cont(gradient, sgd=self.sgd_cont)
+        context_gradient = bp_cont(cont_gradient, sgd=self.sgd_cont)
 
         # gradient : concat (doc+sent) vs. desc
         sent_start = self.ARTICLE_WIDTH
@@ -393,9 +382,6 @@ class EL_Model:
             doc_gradients.append(list(x[0:sent_start]))
             sent_gradients.append(list(x[sent_start:]))
 
-        # print("doc_gradients", doc_gradients)
-        # print("sent_gradients", sent_gradients)
-
         bp_doc(doc_gradients, sgd=self.sgd_article)
         bp_sent(sent_gradients, sgd=self.sgd_sent)
 
@@ -426,74 +412,75 @@ class EL_Model:
                     article_id = f.replace(".txt", "")
                     if cnt % 500 == 0 and to_print:
                         print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
-                    cnt += 1
 
-                    # parse the article text
-                    with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
-                        text = file.read()
-                        article_doc = self.nlp(text)
-                        truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
-                        text_by_article[article_id] = truncated_text
+                    try:
+                        # parse the article text
+                        with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
+                            text = file.read()
+                            article_doc = self.nlp(text)
+                            truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
+                            text_by_article[article_id] = truncated_text
 
-                    # process all positive and negative entities, collect all relevant mentions in this article
-                    for mention, entity_pos in correct_entries[article_id].items():
-                        cluster = article_id + "_" + mention
-                        descr = id_to_descr.get(entity_pos)
-                        entities = set()
-                        if descr:
-                            entity = "E_" + str(next_entity_nr) + "_" + cluster
-                            next_entity_nr += 1
-                            gold_by_entity[entity] = 1
-                            desc_by_entity[entity] = descr
-                            entities.add(entity)
+                        # process all positive and negative entities, collect all relevant mentions in this article
+                        for mention, entity_pos in correct_entries[article_id].items():
+                            cluster = article_id + "_" + mention
+                            descr = id_to_descr.get(entity_pos)
+                            entities = set()
+                            if descr:
+                                entity = "E_" + str(next_entity_nr) + "_" + cluster
+                                next_entity_nr += 1
+                                gold_by_entity[entity] = 1
+                                desc_by_entity[entity] = descr
+                                entities.add(entity)
 
-                            entity_negs = incorrect_entries[article_id][mention]
-                            for entity_neg in entity_negs:
-                                descr = id_to_descr.get(entity_neg)
-                                if descr:
-                                    entity = "E_" + str(next_entity_nr) + "_" + cluster
-                                    next_entity_nr += 1
-                                    gold_by_entity[entity] = 0
-                                    desc_by_entity[entity] = descr
-                                    entities.add(entity)
+                                entity_negs = incorrect_entries[article_id][mention]
+                                for entity_neg in entity_negs:
+                                    descr = id_to_descr.get(entity_neg)
+                                    if descr:
+                                        entity = "E_" + str(next_entity_nr) + "_" + cluster
+                                        next_entity_nr += 1
+                                        gold_by_entity[entity] = 0
+                                        desc_by_entity[entity] = descr
+                                        entities.add(entity)
 
-                        found_matches = 0
-                        if len(entities) > 1:
-                            entities_by_cluster[cluster] = entities
+                            found_matches = 0
+                            if len(entities) > 1:
+                                entities_by_cluster[cluster] = entities
 
-                            # find all matches in the doc for the mentions
-                            # TODO: fix this - doesn't look like all entities are found
-                            matcher = PhraseMatcher(self.nlp.vocab)
-                            patterns = list(self.nlp.tokenizer.pipe([mention]))
+                                # find all matches in the doc for the mentions
+                                # TODO: fix this - doesn't look like all entities are found
+                                matcher = PhraseMatcher(self.nlp.vocab)
+                                patterns = list(self.nlp.tokenizer.pipe([mention]))
 
-                            matcher.add("TerminologyList", None, *patterns)
-                            matches = matcher(article_doc)
+                                matcher.add("TerminologyList", None, *patterns)
+                                matches = matcher(article_doc)
 
+                                # store sentences
+                                for match_id, start, end in matches:
+                                    span = article_doc[start:end]
+                                    if mention == span.text:
+                                        found_matches += 1
+                                        sent_text = span.sent.text
+                                        sent_nr = sentence_by_text.get(sent_text,  None)
+                                        if sent_nr is None:
+                                            sent_nr = "S_" + str(next_sent_nr) + article_id
+                                            next_sent_nr += 1
+                                            text_by_sentence[sent_nr] = sent_text
+                                            sentence_by_text[sent_text] = sent_nr
+                                        article_by_cluster[cluster] = article_id
+                                        sentence_by_cluster[cluster] = sent_nr
 
-                            # store sentences
-                            for match_id, start, end in matches:
-                                found_matches += 1
-                                span = article_doc[start:end]
-                                assert mention == span.text
-                                sent_text = span.sent.text
-                                sent_nr = sentence_by_text.get(sent_text,  None)
-                                if sent_nr is None:
-                                    sent_nr = "S_" + str(next_sent_nr) + article_id
-                                    next_sent_nr += 1
-                                    text_by_sentence[sent_nr] = sent_text
-                                    sentence_by_text[sent_text] = sent_nr
-                                article_by_cluster[cluster] = article_id
-                                sentence_by_cluster[cluster] = sent_nr
 
-                        if found_matches == 0:
-                            # TODO print("Could not find neg instances or sentence matches for", mention, "in", article_id)
-                            entities_by_cluster.pop(cluster, None)
-                            article_by_cluster.pop(cluster, None)
-                            sentence_by_cluster.pop(cluster, None)
-                            for entity in entities:
-                                gold_by_entity.pop(entity, None)
-                                desc_by_entity.pop(entity, None)
+                            if found_matches == 0:
+                                # print("Could not find neg instances or sentence matches for", mention, "in", article_id)
+                                entities_by_cluster.pop(cluster, None)
+                                article_by_cluster.pop(cluster, None)
+                                sentence_by_cluster.pop(cluster, None)
+                                for entity in entities:
+                                    gold_by_entity.pop(entity, None)
+                                    desc_by_entity.pop(entity, None)
+                        cnt += 1
+                    except:
+                        print("Problem parsing article", article_id)
 
         if to_print:
             print()
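
For reference on the update step above: get_loss compares the batch of context encodings against the entity-description encodings, with targets of 1 for correct pairs and -1 for negatives (the negative branch is commented out in this commit), and the commented-out get_cossim_loss import suggests a cosine-similarity objective. A minimal NumPy sketch of such a loss and its gradient for the positive-pair case (illustrative only; the function name and shapes are assumptions, not the implementation in this commit):

import numpy as np

def cossim_loss(context_vectors, desc_vectors):
    # Hypothetical cosine-similarity loss: each context vector is pushed
    # towards its entity-description vector (both arrays: n_samples x width).
    # Returns the scalar loss and the gradient w.r.t. the context vectors.
    eps = 1e-8
    ctx_norm = np.linalg.norm(context_vectors, axis=1, keepdims=True) + eps
    desc_norm = np.linalg.norm(desc_vectors, axis=1, keepdims=True) + eps
    cossim = np.sum(context_vectors * desc_vectors, axis=1, keepdims=True) / (ctx_norm * desc_norm)
    loss = np.mean(1.0 - cossim)
    # d(cos)/d(ctx) = desc / (|ctx| |desc|) - cos * ctx / |ctx|^2
    d_cossim = desc_vectors / (ctx_norm * desc_norm) - cossim * context_vectors / (ctx_norm ** 2)
    gradient = -d_cossim / context_vectors.shape[0]
    return loss, gradient
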
@@ -111,7 +111,7 @@ if __name__ == "__main__":
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=100)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500)
         print()
 
     # STEP 7: apply the EL algorithm on the dev dataset
@@ -120,7 +120,6 @@ if __name__ == "__main__":
         run_el.run_el_dev(kb=my_kb, nlp=my_nlp, training_dir=TRAINING_DIR, limit=2000)
         print()
 
-
     # TODO coreference resolution
     # add_coref()
 
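
The mention lookup wrapped in try/except in the EL_Model diff above is built on spaCy's PhraseMatcher: each mention string is tokenized into a pattern, matched against the article Doc, and the sentence around every exact match is recorded. A self-contained sketch of that pattern using the same v2-style API (the function and variable names here are illustrative, not taken from the repo):

import spacy
from spacy.matcher import PhraseMatcher

def collect_mention_sentences(nlp, article_text, mentions):
    # Find exact occurrences of each mention string in the article and
    # return the text of the sentence each occurrence appears in.
    article_doc = nlp(article_text)
    matcher = PhraseMatcher(nlp.vocab)
    patterns = list(nlp.tokenizer.pipe(mentions))
    matcher.add("TerminologyList", None, *patterns)

    sentences_by_mention = {}
    for match_id, start, end in matcher(article_doc):
        span = article_doc[start:end]
        # span.sent requires sentence boundaries, e.g. from the parser
        sentences_by_mention.setdefault(span.text, set()).add(span.sent.text)
    return sentences_by_mention

# Example usage (assumes an installed English model with a parser):
# nlp = spacy.load("en_core_web_md")
# collect_mention_sentences(nlp, "Douglas Adams wrote novels.", ["Douglas Adams"])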