Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-26 13:41:21 +03:00)
			
		
		
		
	grouping clusters of instances per doc+mention
This commit is contained in:
		
parent c6ca8649d7
commit 9d089c0410
					
				|  | @@ -7,7 +7,7 @@ from os import listdir | |||
| 
 | ||||
| from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator | ||||
| 
 | ||||
| from spacy._ml import SpacyVectors, create_default_optimizer, zero_init | ||||
| from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, cosine | ||||
| 
 | ||||
| from thinc.api import chain | ||||
| from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu | ||||
|  | @@ -33,14 +33,12 @@ class EL_Model(): | |||
|         self.article_encoder = self._simple_encoder(width=300) | ||||
| 
 | ||||
|     def train_model(self, training_dir, entity_descr_output, limit=None, to_print=True): | ||||
|         instances, gold_vectors, entity_descriptions, doc_by_article = self._get_training_data(training_dir, | ||||
|         instances, pos_entities, neg_entities, doc_by_article = self._get_training_data(training_dir, | ||||
|                                                                                                entity_descr_output, | ||||
|                                                                                                limit, to_print) | ||||
| 
 | ||||
|         if to_print: | ||||
|             print("Training on", len(gold_vectors), "instances") | ||||
|             print(" - pos:", len([x for x in gold_vectors if x]), "instances") | ||||
|             print(" - pos:", len([x for x in gold_vectors if not x]), "instances") | ||||
|             print("Training on", len(instances), "instance clusters") | ||||
|             print() | ||||
| 
 | ||||
|         self.sgd_entity = self.begin_training(self.entity_encoder) | ||||
|  | @@ -48,11 +46,20 @@ class EL_Model(): | |||
| 
 | ||||
|         losses = {} | ||||
| 
 | ||||
|         for inst, label, entity_descr in zip(instances, gold_vectors, entity_descriptions): | ||||
|             article = inst.split(sep="_")[0] | ||||
|             entity_id = inst.split(sep="_")[1] | ||||
|             article_doc = doc_by_article[article] | ||||
|             self.update(article_doc, entity_descr, label, losses=losses) | ||||
|         for inst_cluster in instances: | ||||
|             pos_ex = pos_entities.get(inst_cluster) | ||||
|             neg_exs = neg_entities.get(inst_cluster, []) | ||||
| 
 | ||||
|             if pos_ex and neg_exs: | ||||
|                 article = inst_cluster.split(sep="_")[0] | ||||
|                 entity_id = inst_cluster.split(sep="_")[1] | ||||
|                 article_doc = doc_by_article[article] | ||||
|                 self.update(article_doc, pos_ex, neg_exs, losses=losses) | ||||
|             # TODO | ||||
|             # elif not pos_ex: | ||||
|                 # print("Weird. Couldn't find pos example for",  inst_cluster) | ||||
|             # elif not neg_exs: | ||||
|                 # print("Weird. Couldn't find neg examples for",  inst_cluster) | ||||
| 
 | ||||
|     def _simple_encoder(self, width): | ||||
|         with Model.define_operators({">>": chain}): | ||||
|  | @@ -69,22 +76,29 @@ class EL_Model(): | |||
|         sgd = create_default_optimizer(model.ops) | ||||
|         return sgd | ||||
| 
 | ||||
|     def update(self, article_doc, entity_descr, label, drop=0., losses=None): | ||||
|         entity_encoding, entity_bp = self.entity_encoder.begin_update([entity_descr], drop=drop) | ||||
|     def update(self, article_doc, true_entity, false_entities, drop=0., losses=None): | ||||
|         doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop) | ||||
| 
 | ||||
|         true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop) | ||||
|         # true_similarity = cosine(true_entity_encoding, doc_encoding) | ||||
|         # print("true_similarity", true_similarity) | ||||
| 
 | ||||
|         # for false_entity in false_entities: | ||||
|             # false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop) | ||||
|             # false_similarity = cosine(false_entity_encoding, doc_encoding) | ||||
|             # print("false_similarity", false_similarity) | ||||
| 
 | ||||
|         # print("entity/article output dim", len(entity_encoding[0]), len(doc_encoding[0])) | ||||
| 
 | ||||
|         mse, diffs = self._calculate_similarity(entity_encoding, doc_encoding) | ||||
|         mse, diffs = self._calculate_similarity(true_entity_encoding, doc_encoding) | ||||
| 
 | ||||
|         # print() | ||||
| 
 | ||||
|         # TODO: proper backpropagation taking ranking of elements into account ? | ||||
|         # TODO backpropagation also for negative examples | ||||
|         if label: | ||||
|             entity_bp(diffs, sgd=self.sgd_entity) | ||||
|             article_bp(diffs, sgd=self.sgd_article) | ||||
|             print(mse) | ||||
|         true_entity_bp(diffs, sgd=self.sgd_entity) | ||||
|         article_bp(diffs, sgd=self.sgd_article) | ||||
|         print(mse) | ||||
| 
 | ||||
| 
 | ||||
|     # TODO delete ? | ||||
|  | @@ -115,7 +129,7 @@ class EL_Model(): | |||
|             raise ValueError("To calculate similarity, both vectors should be of equal length") | ||||
| 
 | ||||
|         diffs = (vector2 - vector1) | ||||
|         error_sum = (diffs ** 2).sum(axis=1) | ||||
|         error_sum = (diffs ** 2).sum() | ||||
|         mean_square_error = error_sum / len(vector1) | ||||
|         return float(mean_square_error), diffs | ||||
| 
 | ||||
|  | @@ -130,10 +144,10 @@ class EL_Model(): | |||
|                                                                                          collect_incorrect=True) | ||||
| 
 | ||||
|         instances = list() | ||||
|         entity_descriptions = list() | ||||
|         local_vectors = list()   # TODO: local vectors | ||||
|         gold_vectors = list() | ||||
|         doc_by_article = dict() | ||||
|         pos_entities = dict() | ||||
|         neg_entities = dict() | ||||
| 
 | ||||
|         cnt = 0 | ||||
|         for f in listdir(training_dir): | ||||
|  | @@ -149,25 +163,24 @@ class EL_Model(): | |||
|                             doc = self.nlp(text) | ||||
|                             doc_by_article[article_id] = doc | ||||
| 
 | ||||
|                     for mention_pos, entity_pos in correct_entries[article_id].items(): | ||||
|                     for mention, entity_pos in correct_entries[article_id].items(): | ||||
|                         descr = id_to_descr.get(entity_pos) | ||||
|                         if descr: | ||||
|                             instances.append(article_id + "_" + entity_pos) | ||||
|                             doc = self.nlp(descr) | ||||
|                             entity_descriptions.append(doc) | ||||
|                             gold_vectors.append(True) | ||||
|                             instances.append(article_id + "_" + mention) | ||||
|                             doc_descr = self.nlp(descr) | ||||
|                             pos_entities[article_id + "_" + mention] = doc_descr | ||||
| 
 | ||||
|                     for mention_neg, entity_negs in incorrect_entries[article_id].items(): | ||||
|                     for mention, entity_negs in incorrect_entries[article_id].items(): | ||||
|                         for entity_neg in entity_negs: | ||||
|                             descr = id_to_descr.get(entity_neg) | ||||
|                             if descr: | ||||
|                                 instances.append(article_id + "_" + entity_neg) | ||||
|                                 doc = self.nlp(descr) | ||||
|                                 entity_descriptions.append(doc) | ||||
|                                 gold_vectors.append(False) | ||||
|                                 doc_descr = self.nlp(descr) | ||||
|                                 descr_list = neg_entities.get(article_id + "_" + mention, []) | ||||
|                                 descr_list.append(doc_descr) | ||||
|                                 neg_entities[article_id + "_" + mention] = descr_list | ||||
| 
 | ||||
|         if to_print: | ||||
|             print() | ||||
|             print("Processed", cnt, "dev articles") | ||||
|             print() | ||||
|         return instances, gold_vectors, entity_descriptions, doc_by_article | ||||
|         return instances, pos_entities, neg_entities, doc_by_article | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user