mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	update per entity
This commit is contained in:
		
							parent
							
								
									eb08bdb11f
								
							
						
					
					
						commit
						1a16490d20
					
				|  | @ -154,7 +154,7 @@ class EL_Model: | |||
|         if self.PRINT_F: | ||||
|             print("p/r/F", print_string, round(p, 1), round(r, 1), round(f, 1)) | ||||
| 
 | ||||
|         loss, d_scores = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds)) | ||||
|         loss, gradient = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds)) | ||||
|         if self.PRINT_LOSS: | ||||
|             print("loss", print_string, round(loss, 5)) | ||||
| 
 | ||||
|  | @ -235,62 +235,58 @@ class EL_Model: | |||
|     @staticmethod | ||||
|     def get_loss(predictions, golds): | ||||
|         d_scores = (predictions - golds) | ||||
| 
 | ||||
|         loss = (d_scores ** 2).sum() | ||||
|         loss = (d_scores ** 2).mean() | ||||
|         return loss, d_scores | ||||
| 
 | ||||
|     # TODO: multiple docs/articles | ||||
|     def update(self, article_text, entities, golds, apply_threshold=True): | ||||
|         article_doc = self.nlp(article_text) | ||||
|         doc_encodings, bp_doc = self.article_encoder.begin_update([article_doc], drop=self.DROP) | ||||
|         doc_encoding = doc_encodings[0] | ||||
|         # entity_docs = list(self.nlp.pipe(entities)) | ||||
| 
 | ||||
|         entity_docs = list(self.nlp.pipe(entities)) | ||||
|         # print("entity_docs", type(entity_docs)) | ||||
|         for entity, gold in zip(entities, golds): | ||||
|             doc_encodings, bp_doc = self.article_encoder.begin_update([article_doc], drop=self.DROP) | ||||
|             doc_encoding = doc_encodings[0] | ||||
| 
 | ||||
|         entity_encodings, bp_entity = self.entity_encoder.begin_update(entity_docs, drop=self.DROP) | ||||
|         # print("entity_encodings", len(entity_encodings), entity_encodings) | ||||
|             entity_doc = self.nlp(entity) | ||||
|             # print("entity_docs", type(entity_doc)) | ||||
| 
 | ||||
|         concat_encodings = [list(entity_encodings[i]) + list(doc_encoding) for i in range(len(entities))] | ||||
|         # print("concat_encodings", len(concat_encodings), concat_encodings) | ||||
|             entity_encodings, bp_entity = self.entity_encoder.begin_update([entity_doc], drop=self.DROP) | ||||
|             entity_encoding = entity_encodings[0] | ||||
|             # print("entity_encoding", len(entity_encoding), entity_encoding) | ||||
| 
 | ||||
|         predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP) | ||||
|         predictions = self.model.ops.flatten(predictions) | ||||
|             concat_encodings = [list(entity_encoding) + list(doc_encoding)]  #  for i in range(len(entities)) | ||||
|             # print("concat_encodings", len(concat_encodings), concat_encodings) | ||||
| 
 | ||||
|         # print("predictions", predictions) | ||||
|         golds = self.model.ops.asarray(golds) | ||||
|         # print("golds", golds) | ||||
|             prediction, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP) | ||||
|             # predictions = self.model.ops.flatten(predictions) | ||||
| 
 | ||||
|         loss, d_scores = self.get_loss(predictions, golds) | ||||
|             # print("prediction", prediction) | ||||
|             # golds = self.model.ops.asarray(golds) | ||||
|             # print("gold", gold) | ||||
| 
 | ||||
|         if self.PRINT_LOSS and self.PRINT_TRAIN: | ||||
|             print("loss train", round(loss, 5)) | ||||
|             loss, gradient = self.get_loss(prediction, gold) | ||||
| 
 | ||||
|         if self.PRINT_F and self.PRINT_TRAIN: | ||||
|             predictions_f = [x for x in predictions] | ||||
|             if apply_threshold: | ||||
|                 predictions_f = [float(1.0) if x > self.CUTOFF else float(0.0) for x in predictions_f] | ||||
|             p, r, f = run_el.evaluate(predictions_f, golds, to_print=False) | ||||
|             print("p/r/F train", round(p, 1), round(r, 1), round(f, 1)) | ||||
|             if self.PRINT_LOSS and self.PRINT_TRAIN: | ||||
|                 print("loss train", round(loss, 5)) | ||||
| 
 | ||||
|         d_scores = d_scores.reshape((-1, 1)) | ||||
|         d_scores = d_scores.astype(np.float32) | ||||
|         # print("d_scores", d_scores) | ||||
|             gradient = float(gradient) | ||||
|             # print("gradient", gradient) | ||||
|             # print("loss", loss) | ||||
| 
 | ||||
|         model_gradient = bp_model(d_scores, sgd=self.sgd) | ||||
|         # print("model_gradient", model_gradient) | ||||
|             model_gradient = bp_model(gradient, sgd=self.sgd) | ||||
|             # print("model_gradient", model_gradient) | ||||
| 
 | ||||
|         # concat = entity + doc, but doc is the same within this function (TODO: multiple docs/articles) | ||||
|         doc_gradient = model_gradient[0][self.ENTITY_WIDTH:] | ||||
|         entity_gradients = list() | ||||
|         for x in model_gradient: | ||||
|             entity_gradients.append(list(x[0:self.ENTITY_WIDTH])) | ||||
|             # concat = entity + doc, but doc is the same within this function (TODO: multiple docs/articles) | ||||
|             doc_gradient = model_gradient[0][self.ENTITY_WIDTH:] | ||||
|             entity_gradients = list() | ||||
|             for x in model_gradient: | ||||
|                 entity_gradients.append(list(x[0:self.ENTITY_WIDTH])) | ||||
| 
 | ||||
|         # print("doc_gradient", doc_gradient) | ||||
|         # print("entity_gradients", entity_gradients) | ||||
|             # print("doc_gradient", doc_gradient) | ||||
|             # print("entity_gradients", entity_gradients) | ||||
| 
 | ||||
|         bp_doc([doc_gradient], sgd=self.sgd_article) | ||||
|         bp_entity(entity_gradients, sgd=self.sgd_entity) | ||||
|             bp_doc([doc_gradient], sgd=self.sgd_article) | ||||
|             bp_entity(entity_gradients, sgd=self.sgd_entity) | ||||
| 
 | ||||
|     def _get_training_data(self, training_dir, entity_descr_output, dev, limit, balance, to_print): | ||||
|         id_to_descr = kb_creator._get_id_to_description(entity_descr_output) | ||||
|  | @ -326,16 +322,17 @@ class EL_Model: | |||
|                             pos_entities[article_id + "_" + mention] = descr | ||||
| 
 | ||||
|                     for mention, entity_negs in incorrect_entries[article_id].items(): | ||||
|                         neg_count = 0 | ||||
|                         for entity_neg in entity_negs: | ||||
|                             descr = id_to_descr.get(entity_neg) | ||||
|                             if descr: | ||||
|                         if not balance or pos_entities.get(article_id + "_" + mention): | ||||
|                             neg_count = 0 | ||||
|                             for entity_neg in entity_negs: | ||||
|                                 # if balance, keep only 1 negative instance for each positive instance | ||||
|                                 if neg_count < 1 or not balance: | ||||
|                                     descr_list = neg_entities.get(article_id + "_" + mention, []) | ||||
|                                     descr_list.append(descr) | ||||
|                                     neg_entities[article_id + "_" + mention] = descr_list | ||||
|                                     neg_count += 1 | ||||
|                                     descr = id_to_descr.get(entity_neg) | ||||
|                                     if descr: | ||||
|                                         descr_list = neg_entities.get(article_id + "_" + mention, []) | ||||
|                                         descr_list.append(descr) | ||||
|                                         neg_entities[article_id + "_" + mention] = descr_list | ||||
|                                         neg_count += 1 | ||||
| 
 | ||||
|         if to_print: | ||||
|             print() | ||||
|  |  | |||
|  | @ -111,7 +111,7 @@ if __name__ == "__main__": | |||
|         print("STEP 6: training", datetime.datetime.now()) | ||||
|         my_nlp = spacy.load('en_core_web_md') | ||||
|         trainer = EL_Model(kb=my_kb, nlp=my_nlp) | ||||
|         trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=20) | ||||
|         trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=100, devlimit=20) | ||||
|         print() | ||||
| 
 | ||||
|     # STEP 7: apply the EL algorithm on the dev dataset | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user