mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	train and predict per article (saving time for doc encoding)
This commit is contained in:
		
							parent
							
								
									3b81b00954
								
							
						
					
					
						commit
						4142e8dd1b
					
				| 
						 | 
					@ -46,11 +46,11 @@ class EL_Model():
 | 
				
			||||||
        dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir,
 | 
					        dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir,
 | 
				
			||||||
                                                                           entity_descr_output,
 | 
					                                                                           entity_descr_output,
 | 
				
			||||||
                                                                           True,
 | 
					                                                                           True,
 | 
				
			||||||
                                                                           limit, to_print)
 | 
					                                                                           limit / 10, to_print)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if to_print:
 | 
					        if to_print:
 | 
				
			||||||
            print("Training on", len(train_instances), "instance clusters")
 | 
					            print("Training on", len(train_instances.values()), "articles")
 | 
				
			||||||
            print("Dev test on", len(dev_instances), "instance clusters")
 | 
					            print("Dev test on", len(dev_instances.values()), "articles")
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.sgd_entity = self.begin_training(self.entity_encoder)
 | 
					        self.sgd_entity = self.begin_training(self.entity_encoder)
 | 
				
			||||||
| 
						 | 
					@ -60,49 +60,51 @@ class EL_Model():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        losses = {}
 | 
					        losses = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for inst_cluster in train_instances:
 | 
					        instance_count = 0
 | 
				
			||||||
            pos_ex = train_pos.get(inst_cluster)
 | 
					
 | 
				
			||||||
            neg_exs = train_neg.get(inst_cluster, [])
 | 
					        for article_id, inst_cluster_set in train_instances.items():
 | 
				
			||||||
 | 
					            article_doc = train_doc[article_id]
 | 
				
			||||||
 | 
					            pos_ex_list = list()
 | 
				
			||||||
 | 
					            neg_exs_list = list()
 | 
				
			||||||
 | 
					            for inst_cluster in inst_cluster_set:
 | 
				
			||||||
 | 
					                instance_count += 1
 | 
				
			||||||
 | 
					                pos_ex_list.append(train_pos.get(inst_cluster))
 | 
				
			||||||
 | 
					                neg_exs_list.append(train_neg.get(inst_cluster, []))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.update(article_doc, pos_ex_list, neg_exs_list, losses=losses)
 | 
				
			||||||
 | 
					            p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc)
 | 
				
			||||||
 | 
					            print(round(fscore, 1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if to_print:
 | 
				
			||||||
 | 
					            print("Trained on", instance_count, "instance clusters")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if pos_ex and neg_exs:
 | 
					 | 
				
			||||||
                article = inst_cluster.split(sep="_")[0]
 | 
					 | 
				
			||||||
                entity_id = inst_cluster.split(sep="_")[1]
 | 
					 | 
				
			||||||
                article_doc = train_doc[article]
 | 
					 | 
				
			||||||
                self.update(article_doc, pos_ex, neg_exs, losses=losses)
 | 
					 | 
				
			||||||
                p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc)
 | 
					 | 
				
			||||||
                print(round(fscore, 1))
 | 
					 | 
				
			||||||
            # TODO
 | 
					 | 
				
			||||||
            # elif not pos_ex:
 | 
					 | 
				
			||||||
                # print("Weird. Couldn't find pos example for",  inst_cluster)
 | 
					 | 
				
			||||||
            # elif not neg_exs:
 | 
					 | 
				
			||||||
                # print("Weird. Couldn't find neg examples for",  inst_cluster)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc):
 | 
					    def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc):
 | 
				
			||||||
        predictions = list()
 | 
					        predictions = list()
 | 
				
			||||||
        golds = list()
 | 
					        golds = list()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for inst_cluster in dev_instances:
 | 
					        for article_id, inst_cluster_set in dev_instances.items():
 | 
				
			||||||
            pos_ex = dev_pos.get(inst_cluster)
 | 
					            for inst_cluster in inst_cluster_set:
 | 
				
			||||||
            neg_exs = dev_neg.get(inst_cluster, [])
 | 
					                pos_ex = dev_pos.get(inst_cluster)
 | 
				
			||||||
            ex_to_id = dict()
 | 
					                neg_exs = dev_neg.get(inst_cluster, [])
 | 
				
			||||||
 | 
					                ex_to_id = dict()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if pos_ex and neg_exs:
 | 
					                if pos_ex and neg_exs:
 | 
				
			||||||
                ex_to_id[pos_ex] = pos_ex._.entity_id
 | 
					                    ex_to_id[pos_ex] = pos_ex._.entity_id
 | 
				
			||||||
                for neg_ex in neg_exs:
 | 
					                    for neg_ex in neg_exs:
 | 
				
			||||||
                    ex_to_id[neg_ex] = neg_ex._.entity_id
 | 
					                        ex_to_id[neg_ex] = neg_ex._.entity_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                article = inst_cluster.split(sep="_")[0]
 | 
					                    article = inst_cluster.split(sep="_")[0]
 | 
				
			||||||
                entity_id = inst_cluster.split(sep="_")[1]
 | 
					                    entity_id = inst_cluster.split(sep="_")[1]
 | 
				
			||||||
                article_doc = dev_doc[article]
 | 
					                    article_doc = dev_doc[article]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                examples = list(neg_exs)
 | 
					                    examples = list(neg_exs)
 | 
				
			||||||
                examples.append(pos_ex)
 | 
					                    examples.append(pos_ex)
 | 
				
			||||||
                shuffle(examples)
 | 
					                    shuffle(examples)
 | 
				
			||||||
 | 
					 | 
				
			||||||
                best_entity, lowest_mse = self._predict(examples, article_doc)
 | 
					 | 
				
			||||||
                predictions.append(ex_to_id[best_entity])
 | 
					 | 
				
			||||||
                golds.append(ex_to_id[pos_ex])
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    best_entity, lowest_mse = self._predict(examples, article_doc)
 | 
				
			||||||
 | 
					                    predictions.append(ex_to_id[best_entity])
 | 
				
			||||||
 | 
					                    golds.append(ex_to_id[pos_ex])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO: use lowest_mse and combine with prior probability
 | 
					        # TODO: use lowest_mse and combine with prior probability
 | 
				
			||||||
        p, r, F = run_el.evaluate(predictions, golds, to_print=False)
 | 
					        p, r, F = run_el.evaluate(predictions, golds, to_print=False)
 | 
				
			||||||
| 
						 | 
					@ -161,60 +163,79 @@ class EL_Model():
 | 
				
			||||||
        sgd = create_default_optimizer(model.ops)
 | 
					        sgd = create_default_optimizer(model.ops)
 | 
				
			||||||
        return sgd
 | 
					        return sgd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, article_doc, true_entity, false_entities, drop=0., losses=None):
 | 
					    def update(self, article_doc, true_entity_list, false_entities_list, drop=0., losses=None):
 | 
				
			||||||
 | 
					        # TODO: one call only to begin_update ?
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        entity_diffs = None
 | 
				
			||||||
 | 
					        doc_diffs = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop)
 | 
					        doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop)
 | 
					        for i, true_entity in enumerate(true_entity_list):
 | 
				
			||||||
        # print("encoding dim", len(true_entity_encoding[0]))
 | 
					            false_entities = false_entities_list[i]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding)
 | 
					            true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop)
 | 
				
			||||||
        consensus_encoding_t = consensus_encoding.transpose()
 | 
					            # print("encoding dim", len(true_entity_encoding[0]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        doc_mse, doc_diffs = self._calculate_similarity(doc_encoding, consensus_encoding)
 | 
					            consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding)
 | 
				
			||||||
 | 
					            # consensus_encoding_t = consensus_encoding.transpose()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        entity_mses = list()
 | 
					            doc_mse, doc_diff = self._calculate_similarity(doc_encoding, consensus_encoding)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        true_mse, true_diffs = self._calculate_similarity(true_entity_encoding, consensus_encoding)
 | 
					            entity_mses = list()
 | 
				
			||||||
        # print("true_mse", true_mse)
 | 
					 | 
				
			||||||
        # print("true_diffs", true_diffs)
 | 
					 | 
				
			||||||
        entity_mses.append(true_mse)
 | 
					 | 
				
			||||||
        # true_exp = np.exp(true_entity_encoding.dot(consensus_encoding_t))
 | 
					 | 
				
			||||||
        # print("true_exp", true_exp)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # false_exp_sum = 0
 | 
					            true_mse, true_diffs = self._calculate_similarity(true_entity_encoding, consensus_encoding)
 | 
				
			||||||
 | 
					            # print("true_mse", true_mse)
 | 
				
			||||||
 | 
					            # print("true_diffs", true_diffs)
 | 
				
			||||||
 | 
					            entity_mses.append(true_mse)
 | 
				
			||||||
 | 
					            # true_exp = np.exp(true_entity_encoding.dot(consensus_encoding_t))
 | 
				
			||||||
 | 
					            # print("true_exp", true_exp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for false_entity in false_entities:
 | 
					            # false_exp_sum = 0
 | 
				
			||||||
            false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop)
 | 
					 | 
				
			||||||
            false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding)
 | 
					 | 
				
			||||||
            # print("false_mse", false_mse)
 | 
					 | 
				
			||||||
            # false_exp = np.exp(false_entity_encoding.dot(consensus_encoding_t))
 | 
					 | 
				
			||||||
            # print("false_exp", false_exp)
 | 
					 | 
				
			||||||
            # print("false_diffs", false_diffs)
 | 
					 | 
				
			||||||
            entity_mses.append(false_mse)
 | 
					 | 
				
			||||||
            # if false_mse > true_mse:
 | 
					 | 
				
			||||||
                # true_diffs = true_diffs - false_diffs ???
 | 
					 | 
				
			||||||
            # false_exp_sum += false_exp
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # prob = true_exp / false_exp_sum
 | 
					            if doc_diffs is not None:
 | 
				
			||||||
        # print("prob", prob)
 | 
					                doc_diffs += doc_diff
 | 
				
			||||||
 | 
					                entity_diffs += true_diffs
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                doc_diffs = doc_diff
 | 
				
			||||||
 | 
					                entity_diffs = true_diffs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        entity_mses = sorted(entity_mses)
 | 
					            for false_entity in false_entities:
 | 
				
			||||||
        # mse_sum = sum(entity_mses)
 | 
					                false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop)
 | 
				
			||||||
        # entity_probs = [1 - x/mse_sum for x in entity_mses]
 | 
					                false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding)
 | 
				
			||||||
        # print("entity_mses", entity_mses)
 | 
					                # print("false_mse", false_mse)
 | 
				
			||||||
        # print("entity_probs", entity_probs)
 | 
					                # false_exp = np.exp(false_entity_encoding.dot(consensus_encoding_t))
 | 
				
			||||||
        true_index = entity_mses.index(true_mse)
 | 
					                # print("false_exp", false_exp)
 | 
				
			||||||
        # print("true index", true_index)
 | 
					                # print("false_diffs", false_diffs)
 | 
				
			||||||
        # print("true prob", entity_probs[true_index])
 | 
					                entity_mses.append(false_mse)
 | 
				
			||||||
 | 
					                # if false_mse > true_mse:
 | 
				
			||||||
 | 
					                    # true_diffs = true_diffs - false_diffs ???
 | 
				
			||||||
 | 
					                # false_exp_sum += false_exp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # print("training loss", true_mse)
 | 
					            # prob = true_exp / false_exp_sum
 | 
				
			||||||
 | 
					            # print("prob", prob)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # print()
 | 
					            entity_mses = sorted(entity_mses)
 | 
				
			||||||
 | 
					            # mse_sum = sum(entity_mses)
 | 
				
			||||||
 | 
					            # entity_probs = [1 - x/mse_sum for x in entity_mses]
 | 
				
			||||||
 | 
					            # print("entity_mses", entity_mses)
 | 
				
			||||||
 | 
					            # print("entity_probs", entity_probs)
 | 
				
			||||||
 | 
					            true_index = entity_mses.index(true_mse)
 | 
				
			||||||
 | 
					            # print("true index", true_index)
 | 
				
			||||||
 | 
					            # print("true prob", entity_probs[true_index])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # print("training loss", true_mse)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # print()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO: proper backpropagation taking ranking of elements into account ?
 | 
					        # TODO: proper backpropagation taking ranking of elements into account ?
 | 
				
			||||||
        # TODO backpropagation also for negative examples
 | 
					        # TODO backpropagation also for negative examples
 | 
				
			||||||
        true_entity_bp(true_diffs, sgd=self.sgd_entity)
 | 
					
 | 
				
			||||||
        article_bp(doc_diffs, sgd=self.sgd_article)
 | 
					        if doc_diffs is not None:
 | 
				
			||||||
 | 
					            doc_diffs = doc_diffs / len(true_entity_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            true_entity_bp(entity_diffs, sgd=self.sgd_entity)
 | 
				
			||||||
 | 
					            article_bp(doc_diffs, sgd=self.sgd_article)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO delete ?
 | 
					    # TODO delete ?
 | 
				
			||||||
| 
						 | 
					@ -268,7 +289,7 @@ class EL_Model():
 | 
				
			||||||
                                                                                         collect_incorrect=True)
 | 
					                                                                                         collect_incorrect=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        instances = list()
 | 
					        instance_by_doc = dict()
 | 
				
			||||||
        local_vectors = list()   # TODO: local vectors
 | 
					        local_vectors = list()   # TODO: local vectors
 | 
				
			||||||
        doc_by_article = dict()
 | 
					        doc_by_article = dict()
 | 
				
			||||||
        pos_entities = dict()
 | 
					        pos_entities = dict()
 | 
				
			||||||
| 
						 | 
					@ -280,18 +301,19 @@ class EL_Model():
 | 
				
			||||||
                if dev == run_el.is_dev(f):
 | 
					                if dev == run_el.is_dev(f):
 | 
				
			||||||
                    article_id = f.replace(".txt", "")
 | 
					                    article_id = f.replace(".txt", "")
 | 
				
			||||||
                    if cnt % 500 == 0 and to_print:
 | 
					                    if cnt % 500 == 0 and to_print:
 | 
				
			||||||
                        print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset")
 | 
					                        print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
 | 
				
			||||||
                    cnt += 1
 | 
					                    cnt += 1
 | 
				
			||||||
                    if article_id not in doc_by_article:
 | 
					                    if article_id not in doc_by_article:
 | 
				
			||||||
                        with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
 | 
					                        with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
 | 
				
			||||||
                            text = file.read()
 | 
					                            text = file.read()
 | 
				
			||||||
                            doc = self.nlp(text)
 | 
					                            doc = self.nlp(text)
 | 
				
			||||||
                            doc_by_article[article_id] = doc
 | 
					                            doc_by_article[article_id] = doc
 | 
				
			||||||
 | 
					                            instance_by_doc[article_id] = set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    for mention, entity_pos in correct_entries[article_id].items():
 | 
					                    for mention, entity_pos in correct_entries[article_id].items():
 | 
				
			||||||
                        descr = id_to_descr.get(entity_pos)
 | 
					                        descr = id_to_descr.get(entity_pos)
 | 
				
			||||||
                        if descr:
 | 
					                        if descr:
 | 
				
			||||||
                            instances.append(article_id + "_" + mention)
 | 
					                            instance_by_doc[article_id].add(article_id + "_" + mention)
 | 
				
			||||||
                            doc_descr = self.nlp(descr)
 | 
					                            doc_descr = self.nlp(descr)
 | 
				
			||||||
                            doc_descr._.entity_id = entity_pos
 | 
					                            doc_descr._.entity_id = entity_pos
 | 
				
			||||||
                            pos_entities[article_id + "_" + mention] = doc_descr
 | 
					                            pos_entities[article_id + "_" + mention] = doc_descr
 | 
				
			||||||
| 
						 | 
					@ -308,6 +330,6 @@ class EL_Model():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if to_print:
 | 
					        if to_print:
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
            print("Processed", cnt, "dev articles")
 | 
					            print("Processed", cnt, "training articles, dev=" + str(dev))
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
        return instances, pos_entities, neg_entities, doc_by_article
 | 
					        return instance_by_doc, pos_entities, neg_entities, doc_by_article
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -111,7 +111,7 @@ if __name__ == "__main__":
 | 
				
			||||||
        print("STEP 6: training ", datetime.datetime.now())
 | 
					        print("STEP 6: training ", datetime.datetime.now())
 | 
				
			||||||
        my_nlp = spacy.load('en_core_web_md')
 | 
					        my_nlp = spacy.load('en_core_web_md')
 | 
				
			||||||
        trainer = EL_Model(kb=my_kb, nlp=my_nlp)
 | 
					        trainer = EL_Model(kb=my_kb, nlp=my_nlp)
 | 
				
			||||||
        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=50)
 | 
					        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=500)
 | 
				
			||||||
        print()
 | 
					        print()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # STEP 7: apply the EL algorithm on the dev dataset
 | 
					    # STEP 7: apply the EL algorithm on the dev dataset
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user