mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	refactor again to clusters of entities and cosine similarity
This commit is contained in:
		
							parent
							
								
									8c4aa076bc
								
							
						
					
					
						commit
						992fa92b66
					
				| 
						 | 
					@ -11,7 +11,7 @@ from thinc.neural._classes.convolution import ExtractWindow
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
 | 
					from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic, Tok2Vec
 | 
					from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic, Tok2Vec, cosine
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
 | 
					from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
 | 
				
			||||||
from thinc.v2v import Model, Maxout, Affine, ReLu
 | 
					from thinc.v2v import Model, Maxout, Affine, ReLu
 | 
				
			||||||
| 
						 | 
					@ -20,6 +20,7 @@ from thinc.t2t import ParametricAttention
 | 
				
			||||||
from thinc.misc import Residual
 | 
					from thinc.misc import Residual
 | 
				
			||||||
from thinc.misc import LayerNorm as LN
 | 
					from thinc.misc import LayerNorm as LN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from spacy.cli.pretrain import get_cossim_loss
 | 
				
			||||||
from spacy.matcher import PhraseMatcher
 | 
					from spacy.matcher import PhraseMatcher
 | 
				
			||||||
from spacy.tokens import Doc
 | 
					from spacy.tokens import Doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -34,20 +35,20 @@ class EL_Model:
 | 
				
			||||||
    CUTOFF = 0.5
 | 
					    CUTOFF = 0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    BATCH_SIZE = 5
 | 
					    BATCH_SIZE = 5
 | 
				
			||||||
    UPSAMPLE = True
 | 
					    # UPSAMPLE = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DOC_CUTOFF = 300    # number of characters from the doc context
 | 
					    DOC_CUTOFF = 300    # number of characters from the doc context
 | 
				
			||||||
    INPUT_DIM = 300     # dimension of pre-trained vectors
 | 
					    INPUT_DIM = 300     # dimension of pre-trained vectors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # HIDDEN_1_WIDTH = 32   # 10
 | 
					    HIDDEN_1_WIDTH = 32
 | 
				
			||||||
    HIDDEN_2_WIDTH = 32  # 6
 | 
					    # HIDDEN_2_WIDTH = 32  # 6
 | 
				
			||||||
    DESC_WIDTH = 64     # 4
 | 
					    DESC_WIDTH = 64
 | 
				
			||||||
    ARTICLE_WIDTH = 64   # 8
 | 
					    ARTICLE_WIDTH = 64
 | 
				
			||||||
    SENT_WIDTH = 64
 | 
					    SENT_WIDTH = 64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DROP = 0.1
 | 
					    DROP = 0.1
 | 
				
			||||||
    LEARN_RATE = 0.0001
 | 
					    LEARN_RATE = 0.0001
 | 
				
			||||||
    EPOCHS = 20
 | 
					    EPOCHS = 10
 | 
				
			||||||
    L2 = 1e-6
 | 
					    L2 = 1e-6
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    name = "entity_linker"
 | 
					    name = "entity_linker"
 | 
				
			||||||
| 
						 | 
					@ -57,9 +58,10 @@ class EL_Model:
 | 
				
			||||||
        self.nlp = nlp
 | 
					        self.nlp = nlp
 | 
				
			||||||
        self.kb = kb
 | 
					        self.kb = kb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._build_cnn(desc_width=self.DESC_WIDTH,
 | 
					        self._build_cnn(embed_width=self.INPUT_DIM,
 | 
				
			||||||
 | 
					                        desc_width=self.DESC_WIDTH,
 | 
				
			||||||
                        article_width=self.ARTICLE_WIDTH,
 | 
					                        article_width=self.ARTICLE_WIDTH,
 | 
				
			||||||
                        sent_width=self.SENT_WIDTH)
 | 
					                        sent_width=self.SENT_WIDTH, hidden_1_width=self.HIDDEN_1_WIDTH)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True):
 | 
					    def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True):
 | 
				
			||||||
        # raise errors instead of runtime warnings in case of int/float overflow
 | 
					        # raise errors instead of runtime warnings in case of int/float overflow
 | 
				
			||||||
| 
						 | 
					@ -70,24 +72,28 @@ class EL_Model:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
 | 
					        train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
 | 
				
			||||||
            self._get_training_data(training_dir, entity_descr_output, False, trainlimit, to_print=False)
 | 
					            self._get_training_data(training_dir, entity_descr_output, False, trainlimit, to_print=False)
 | 
				
			||||||
 | 
					        train_clusters = list(train_ent.keys())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts = \
 | 
					        dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts = \
 | 
				
			||||||
            self._get_training_data(training_dir, entity_descr_output, True, devlimit, to_print=False)
 | 
					            self._get_training_data(training_dir, entity_descr_output, True, devlimit, to_print=False)
 | 
				
			||||||
 | 
					        dev_clusters = list(dev_ent.keys())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dev_pos_count = len([g for g in dev_gold.values() if g])
 | 
					        dev_pos_count = len([g for g in dev_gold.values() if g])
 | 
				
			||||||
        dev_neg_count = len([g for g in dev_gold.values() if not g])
 | 
					        dev_neg_count = len([g for g in dev_gold.values() if not g])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # inspect data
 | 
					        # inspect data
 | 
				
			||||||
        if self.PRINT_INSPECT:
 | 
					        if self.PRINT_INSPECT:
 | 
				
			||||||
            for entity in train_ent:
 | 
					            for cluster, entities in train_ent.items():
 | 
				
			||||||
                print("entity", entity)
 | 
					 | 
				
			||||||
                print("gold", train_gold[entity])
 | 
					 | 
				
			||||||
                print("desc", train_desc[entity])
 | 
					 | 
				
			||||||
                print("sentence ID", train_sent[entity])
 | 
					 | 
				
			||||||
                print("sentence text", train_sent_texts[train_sent[entity]])
 | 
					 | 
				
			||||||
                print("article ID", train_art[entity])
 | 
					 | 
				
			||||||
                print("article text", train_art_texts[train_art[entity]])
 | 
					 | 
				
			||||||
                print()
 | 
					                print()
 | 
				
			||||||
 | 
					                for entity in entities:
 | 
				
			||||||
 | 
					                    print("entity", entity)
 | 
				
			||||||
 | 
					                    print("gold", train_gold[entity])
 | 
				
			||||||
 | 
					                    print("desc", train_desc[entity])
 | 
				
			||||||
 | 
					                    print("sentence ID", train_sent[entity])
 | 
				
			||||||
 | 
					                    print("sentence text", train_sent_texts[train_sent[entity]])
 | 
				
			||||||
 | 
					                    print("article ID", train_art[entity])
 | 
				
			||||||
 | 
					                    print("article text", train_art_texts[train_art[entity]])
 | 
				
			||||||
 | 
					                    print()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        train_pos_entities = [k for k, v in train_gold.items() if v]
 | 
					        train_pos_entities = [k for k, v in train_gold.items() if v]
 | 
				
			||||||
        train_neg_entities = [k for k, v in train_gold.items() if not v]
 | 
					        train_neg_entities = [k for k, v in train_gold.items() if not v]
 | 
				
			||||||
| 
						 | 
					@ -95,29 +101,29 @@ class EL_Model:
 | 
				
			||||||
        train_pos_count = len(train_pos_entities)
 | 
					        train_pos_count = len(train_pos_entities)
 | 
				
			||||||
        train_neg_count = len(train_neg_entities)
 | 
					        train_neg_count = len(train_neg_entities)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.UPSAMPLE:
 | 
					        # if self.UPSAMPLE:
 | 
				
			||||||
            if to_print:
 | 
					        #    if to_print:
 | 
				
			||||||
                print()
 | 
					        #        print()
 | 
				
			||||||
                print("Upsampling, original training instances pos/neg:", train_pos_count, train_neg_count)
 | 
					        #        print("Upsampling, original training instances pos/neg:", train_pos_count, train_neg_count)
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
            # upsample positives to 50-50 distribution
 | 
					        #    # upsample positives to 50-50 distribution
 | 
				
			||||||
            while train_pos_count < train_neg_count:
 | 
					        #    while train_pos_count < train_neg_count:
 | 
				
			||||||
                train_ent.append(random.choice(train_pos_entities))
 | 
					        #        train_ent.append(random.choice(train_pos_entities))
 | 
				
			||||||
                train_pos_count += 1
 | 
					        #        train_pos_count += 1
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
            # upsample negatives to 50-50 distribution
 | 
					            # upsample negatives to 50-50 distribution
 | 
				
			||||||
            while train_neg_count < train_pos_count:
 | 
					        #    while train_neg_count < train_pos_count:
 | 
				
			||||||
                train_ent.append(random.choice(train_neg_entities))
 | 
					        #        train_ent.append(random.choice(train_neg_entities))
 | 
				
			||||||
                train_neg_count += 1
 | 
					        #        train_neg_count += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._begin_training()
 | 
					        self._begin_training()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if to_print:
 | 
					        if to_print:
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
            print("Training on", len(train_ent), "entities in", len(train_art_texts), "articles")
 | 
					            print("Training on", len(train_clusters), "entity clusters in", len(train_art_texts), "articles")
 | 
				
			||||||
            print("Training instances pos/neg:", train_pos_count, train_neg_count)
 | 
					            print("Training instances pos/neg:", train_pos_count, train_neg_count)
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
            print("Dev test on", len(dev_ent), "entities in", len(dev_art_texts), "articles")
 | 
					            print("Dev test on", len(dev_clusters), "entity clusters in", len(dev_art_texts), "articles")
 | 
				
			||||||
            print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
 | 
					            print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
            print(" CUTOFF", self.CUTOFF)
 | 
					            print(" CUTOFF", self.CUTOFF)
 | 
				
			||||||
| 
						 | 
					@ -138,94 +144,104 @@ class EL_Model:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
 | 
					        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
 | 
				
			||||||
                       print_string="dev_pre", avg=True)
 | 
					                       print_string="dev_pre", avg=True)
 | 
				
			||||||
        print()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        processed = 0
 | 
					        processed = 0
 | 
				
			||||||
        for i in range(self.EPOCHS):
 | 
					        for i in range(self.EPOCHS):
 | 
				
			||||||
            shuffle(train_ent)
 | 
					            shuffle(train_clusters)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            start = 0
 | 
					            start = 0
 | 
				
			||||||
            stop = min(self.BATCH_SIZE, len(train_ent))
 | 
					            stop = min(self.BATCH_SIZE, len(train_clusters))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            while start < len(train_ent):
 | 
					            while start < len(train_clusters):
 | 
				
			||||||
                next_batch = train_ent[start:stop]
 | 
					                next_batch = {c: train_ent[c] for c in train_clusters[start:stop]}
 | 
				
			||||||
 | 
					                processed += len(next_batch.keys())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                golds = [train_gold[e] for e in next_batch]
 | 
					                self.update(entity_clusters=next_batch, golds=train_gold, descs=train_desc,
 | 
				
			||||||
                descs = [train_desc[e] for e in next_batch]
 | 
					                            art_texts=train_art_texts, arts=train_art,
 | 
				
			||||||
                article_texts = [train_art_texts[train_art[e]] for e in next_batch]
 | 
					                            sent_texts=train_sent_texts, sents=train_sent)
 | 
				
			||||||
                sent_texts = [train_sent_texts[train_sent[e]] for e in next_batch]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                self.update(entities=next_batch, golds=golds, descs=descs, art_texts=article_texts, sent_texts=sent_texts)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                processed += len(next_batch)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                start = start + self.BATCH_SIZE
 | 
					                start = start + self.BATCH_SIZE
 | 
				
			||||||
                stop = min(stop + self.BATCH_SIZE, len(train_ent))
 | 
					                stop = min(stop + self.BATCH_SIZE, len(train_clusters))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if self.PRINT_TRAIN:
 | 
					            if self.PRINT_TRAIN:
 | 
				
			||||||
                print()
 | 
					                print()
 | 
				
			||||||
                self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts,
 | 
					                self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts,
 | 
				
			||||||
                               print_string="train_inter_epoch " + str(i), avg=True)
 | 
					                                print_string="train_inter_epoch " + str(i), avg=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
 | 
					            self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
 | 
				
			||||||
                           print_string="dev_inter_epoch " + str(i), avg=True)
 | 
					                           print_string="dev_inter_epoch " + str(i), avg=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if to_print:
 | 
					        if to_print:
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
            print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
 | 
					            print("Trained on", processed, "entity clusters across", self.EPOCHS, "epochs")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _test_dev(self, entities, gold_by_entity, desc_by_entity, art_by_entity, art_texts, sent_by_entity, sent_texts,
 | 
					    def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts,
 | 
				
			||||||
                  print_string, avg=True, calc_random=False):
 | 
					                  print_string, avg=True, calc_random=False):
 | 
				
			||||||
        golds = [gold_by_entity[e] for e in entities]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if calc_random:
 | 
					        correct = 0
 | 
				
			||||||
            predictions = self._predict_random(entities=entities)
 | 
					        incorrect = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        else:
 | 
					        for cluster, entities in entity_clusters.items():
 | 
				
			||||||
            desc_docs = self.nlp.pipe([desc_by_entity[e] for e in entities])
 | 
					            correct_entities = [e for e in entities if golds[e]]
 | 
				
			||||||
            article_docs = self.nlp.pipe([art_texts[art_by_entity[e]] for e in entities])
 | 
					            incorrect_entities = [e for e in entities if not golds[e]]
 | 
				
			||||||
            sent_docs = self.nlp.pipe([sent_texts[sent_by_entity[e]] for e in entities])
 | 
					            assert len(correct_entities) == 1
 | 
				
			||||||
            predictions = self._predict(entities=entities, article_docs=article_docs, sent_docs=sent_docs,
 | 
					 | 
				
			||||||
                                        desc_docs=desc_docs, avg=avg)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO: combine with prior probability
 | 
					            entities = list(entities)
 | 
				
			||||||
        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False, times_hundred=False)
 | 
					            shuffle(entities)
 | 
				
			||||||
        loss, gradient = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        print("p/r/F/acc/loss", print_string, round(p, 2), round(r, 2), round(f, 2), round(acc, 2), round(loss, 2))
 | 
					            if calc_random:
 | 
				
			||||||
 | 
					                predicted_entity = random.choice(entities)
 | 
				
			||||||
 | 
					                if predicted_entity in correct_entities:
 | 
				
			||||||
 | 
					                    correct += 1
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    incorrect += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return loss, p, r, f
 | 
					            else:
 | 
				
			||||||
 | 
					                desc_docs = self.nlp.pipe([descs[e] for e in entities])
 | 
				
			||||||
 | 
					                # article_texts = [art_texts[arts[e]] for e in entities]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _predict(self, entities, article_docs, sent_docs, desc_docs, avg=True, apply_threshold=True):
 | 
					                sent_doc = self.nlp(sent_texts[sents[cluster]])
 | 
				
			||||||
 | 
					                article_doc = self.nlp(art_texts[arts[cluster]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                predicted_index = self._predict(article_doc=article_doc, sent_doc=sent_doc,
 | 
				
			||||||
 | 
					                                                desc_docs=desc_docs, avg=avg)
 | 
				
			||||||
 | 
					                if entities[predicted_index] in correct_entities:
 | 
				
			||||||
 | 
					                    correct += 1
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    incorrect += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if correct == incorrect == 0:
 | 
				
			||||||
 | 
					            print("acc", print_string, "NA")
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        acc = correct / (correct + incorrect)
 | 
				
			||||||
 | 
					        print("acc", print_string, round(acc, 2))
 | 
				
			||||||
 | 
					        return acc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
 | 
				
			||||||
        if avg:
 | 
					        if avg:
 | 
				
			||||||
            with self.article_encoder.use_params(self.sgd_article.averages) \
 | 
					            with self.article_encoder.use_params(self.sgd_article.averages) \
 | 
				
			||||||
                 and self.desc_encoder.use_params(self.sgd_desc.averages):
 | 
					                 and self.desc_encoder.use_params(self.sgd_desc.averages)\
 | 
				
			||||||
                doc_encodings = self.article_encoder(article_docs)
 | 
					                 and self.sent_encoder.use_params(self.sgd_sent.averages):
 | 
				
			||||||
 | 
					                # doc_encoding = self.article_encoder(article_doc)
 | 
				
			||||||
                desc_encodings = self.desc_encoder(desc_docs)
 | 
					                desc_encodings = self.desc_encoder(desc_docs)
 | 
				
			||||||
                sent_encodings = self.sent_encoder(sent_docs)
 | 
					                sent_encoding = self.sent_encoder([sent_doc])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            doc_encodings = self.article_encoder(article_docs)
 | 
					            # doc_encodings = self.article_encoder(article_docs)
 | 
				
			||||||
            desc_encodings = self.desc_encoder(desc_docs)
 | 
					            desc_encodings = self.desc_encoder(desc_docs)
 | 
				
			||||||
            sent_encodings = self.sent_encoder(sent_docs)
 | 
					            sent_encoding = self.sent_encoder([sent_doc])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) + list(desc_encodings[i]) for i in
 | 
					        sent_enc = np.transpose(sent_encoding)
 | 
				
			||||||
                            range(len(entities))]
 | 
					        highest_sim = -5
 | 
				
			||||||
 | 
					        best_i = -1
 | 
				
			||||||
 | 
					        for i, desc_enc in enumerate(desc_encodings):
 | 
				
			||||||
 | 
					            sim = cosine(desc_enc, sent_enc)
 | 
				
			||||||
 | 
					            if sim >= highest_sim:
 | 
				
			||||||
 | 
					                best_i = i
 | 
				
			||||||
 | 
					                highest_sim = sim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        np_array_list = np.asarray(concat_encodings)
 | 
					        return best_i
 | 
				
			||||||
 | 
					 | 
				
			||||||
        if avg:
 | 
					 | 
				
			||||||
            with self.model.use_params(self.sgd.averages):
 | 
					 | 
				
			||||||
                predictions = self.model(np_array_list)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            predictions = self.model(np_array_list)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        predictions = self.model.ops.flatten(predictions)
 | 
					 | 
				
			||||||
        predictions = [float(p) for p in predictions]
 | 
					 | 
				
			||||||
        if apply_threshold:
 | 
					 | 
				
			||||||
            predictions = [float(1.0) if p > self.CUTOFF else float(0.0) for p in predictions]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return predictions
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _predict_random(self, entities, apply_threshold=True):
 | 
					    def _predict_random(self, entities, apply_threshold=True):
 | 
				
			||||||
        if not apply_threshold:
 | 
					        if not apply_threshold:
 | 
				
			||||||
| 
						 | 
					@ -233,47 +249,23 @@ class EL_Model:
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return [float(1.0) if random.uniform(0, 1) > self.CUTOFF else float(0.0) for _ in entities]
 | 
					            return [float(1.0) if random.uniform(0, 1) > self.CUTOFF else float(0.0) for _ in entities]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _build_cnn_depr(self, embed_width, desc_width, article_width, sent_width, hidden_1_width, hidden_2_width):
 | 
					    def _build_cnn(self, embed_width, desc_width, article_width, sent_width, hidden_1_width):
 | 
				
			||||||
        with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
 | 
					        with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
 | 
				
			||||||
            self.desc_encoder = self._encoder_depr(in_width=embed_width, hidden_with=hidden_1_width, end_width=desc_width)
 | 
					            self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
 | 
				
			||||||
            self.article_encoder = self._encoder_depr(in_width=embed_width, hidden_with=hidden_1_width, end_width=article_width)
 | 
					                                                   end_width=desc_width)
 | 
				
			||||||
            self.sent_encoder = self._encoder_depr(in_width=embed_width, hidden_with=hidden_1_width, end_width=sent_width)
 | 
					            self.article_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
 | 
				
			||||||
 | 
					                                                      end_width=article_width)
 | 
				
			||||||
 | 
					            self.sent_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
 | 
				
			||||||
 | 
					                                                   end_width=sent_width)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            in_width = article_width + sent_width + desc_width
 | 
					    # def _encoder(self, width):
 | 
				
			||||||
            out_width = hidden_2_width
 | 
					    #    tok2vec = Tok2Vec(width=width, embed_size=2000, pretrained_vectors=self.nlp.vocab.vectors.name, cnn_maxout_pieces=3,
 | 
				
			||||||
 | 
					    #                      subword_features=False, conv_depth=4, bilstm_depth=0)
 | 
				
			||||||
            self.model = Affine(out_width, in_width) \
 | 
					    #
 | 
				
			||||||
                >> LN(Maxout(out_width, out_width)) \
 | 
					    #    return tok2vec >> flatten_add_lengths >> Pooling(mean_pool)
 | 
				
			||||||
                >> Affine(1, out_width) \
 | 
					 | 
				
			||||||
                >> logistic
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _build_cnn(self, desc_width, article_width, sent_width):
 | 
					 | 
				
			||||||
        with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
 | 
					 | 
				
			||||||
            self.desc_encoder = self._encoder(width=desc_width)
 | 
					 | 
				
			||||||
            self.article_encoder = self._encoder(width=article_width)
 | 
					 | 
				
			||||||
            self.sent_encoder = self._encoder(width=sent_width)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            in_width = desc_width + article_width + sent_width
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            self.model = Affine(self.HIDDEN_2_WIDTH, in_width) \
 | 
					 | 
				
			||||||
                         >> LN(Maxout(self.HIDDEN_2_WIDTH, self.HIDDEN_2_WIDTH)) \
 | 
					 | 
				
			||||||
                         >> Affine(1, self.HIDDEN_2_WIDTH) \
 | 
					 | 
				
			||||||
                         >> logistic
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # output_layer = (
 | 
					 | 
				
			||||||
            #         zero_init(Affine(1, in_width, drop_factor=0.0)) >> logistic
 | 
					 | 
				
			||||||
            # )
 | 
					 | 
				
			||||||
            # self.model = output_layer
 | 
					 | 
				
			||||||
            self.model.nO = 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _encoder(self, width):
 | 
					 | 
				
			||||||
        tok2vec = Tok2Vec(width=width, embed_size=2000, pretrained_vectors=self.nlp.vocab.vectors.name, cnn_maxout_pieces=3,
 | 
					 | 
				
			||||||
                          subword_features=False, conv_depth=4, bilstm_depth=0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return tok2vec >> flatten_add_lengths >> Pooling(mean_pool)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def _encoder_depr(in_width, hidden_with, end_width):
 | 
					    def _encoder(in_width, hidden_with, end_width):
 | 
				
			||||||
        conv_depth = 2
 | 
					        conv_depth = 2
 | 
				
			||||||
        cnn_maxout_pieces = 3
 | 
					        cnn_maxout_pieces = 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -307,64 +299,58 @@ class EL_Model:
 | 
				
			||||||
        self.sgd_desc.learn_rate = self.LEARN_RATE
 | 
					        self.sgd_desc.learn_rate = self.LEARN_RATE
 | 
				
			||||||
        self.sgd_desc.L2 = self.L2
 | 
					        self.sgd_desc.L2 = self.L2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.sgd = create_default_optimizer(self.model.ops)
 | 
					        # self.sgd = create_default_optimizer(self.model.ops)
 | 
				
			||||||
        self.sgd.learn_rate = self.LEARN_RATE
 | 
					        # self.sgd.learn_rate = self.LEARN_RATE
 | 
				
			||||||
        self.sgd.L2 = self.L2
 | 
					        # self.sgd.L2 = self.L2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def get_loss(predictions, golds):
 | 
					    def get_loss(predictions, golds):
 | 
				
			||||||
        d_scores = (predictions - golds)
 | 
					        loss, gradients = get_cossim_loss(predictions, golds)
 | 
				
			||||||
        gradient = d_scores.mean()
 | 
					        return loss, gradients
 | 
				
			||||||
        loss = (d_scores ** 2).mean()
 | 
					 | 
				
			||||||
        return loss, gradient
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, entities, golds, descs, art_texts, sent_texts):
 | 
					    def update(self, entity_clusters, golds, descs, art_texts, arts, sent_texts, sents):
 | 
				
			||||||
        golds = self.model.ops.asarray(golds)
 | 
					        for cluster, entities in entity_clusters.items():
 | 
				
			||||||
 | 
					            correct_entities = [e for e in entities if golds[e]]
 | 
				
			||||||
 | 
					            incorrect_entities = [e for e in entities if not golds[e]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        art_docs = self.nlp.pipe(art_texts)
 | 
					            assert len(correct_entities) == 1
 | 
				
			||||||
        sent_docs = self.nlp.pipe(sent_texts)
 | 
					            entities = list(entities)
 | 
				
			||||||
        desc_docs = self.nlp.pipe(descs)
 | 
					            shuffle(entities)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
 | 
					            # article_text = art_texts[arts[cluster]]
 | 
				
			||||||
        sent_encodings, bp_sent = self.sent_encoder.begin_update(sent_docs, drop=self.DROP)
 | 
					            cluster_sent = sent_texts[sents[cluster]]
 | 
				
			||||||
        desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) + list(desc_encodings[i])
 | 
					            # art_docs = self.nlp.pipe(article_text)
 | 
				
			||||||
                            for i in range(len(entities))]
 | 
					            sent_doc = self.nlp(cluster_sent)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
 | 
					            for e in entities:
 | 
				
			||||||
        predictions = self.model.ops.flatten(predictions)
 | 
					                if golds[e]:
 | 
				
			||||||
 | 
					                 # TODO: more appropriate loss for the whole cluster (currently only pos entities)
 | 
				
			||||||
 | 
					                 #  TODO: speed up
 | 
				
			||||||
 | 
					                    desc_doc = self.nlp(descs[e])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # print("entities", entities)
 | 
					                    # doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
 | 
				
			||||||
        # print("predictions", predictions)
 | 
					                    sent_encodings, bp_sent = self.sent_encoder.begin_update([sent_doc], drop=self.DROP)
 | 
				
			||||||
        # print("golds", golds)
 | 
					                    desc_encodings, bp_desc = self.desc_encoder.begin_update([desc_doc], drop=self.DROP)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        loss, gradient = self.get_loss(predictions, golds)
 | 
					                    sent_encoding = sent_encodings[0]
 | 
				
			||||||
 | 
					                    desc_encoding = desc_encodings[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        gradient = float(gradient)
 | 
					                    sent_enc = self.sent_encoder.ops.asarray([sent_encoding])
 | 
				
			||||||
        # print("gradient", gradient)
 | 
					                    desc_enc = self.sent_encoder.ops.asarray([desc_encoding])
 | 
				
			||||||
        # print("loss", loss)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        model_gradient = bp_model(gradient, sgd=self.sgd)
 | 
					                    # print("sent_encoding", type(sent_encoding), sent_encoding)
 | 
				
			||||||
        # print("model_gradient", model_gradient)
 | 
					                    # print("desc_encoding", type(desc_encoding), desc_encoding)
 | 
				
			||||||
 | 
					                    # print("getting los for entity", e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # concat = doc + sent + desc, but doc is the same within this function
 | 
					                    loss, gradient = self.get_loss(sent_enc, desc_enc)
 | 
				
			||||||
        sent_start = self.ARTICLE_WIDTH
 | 
					 | 
				
			||||||
        desc_start = self.ARTICLE_WIDTH + self.SENT_WIDTH
 | 
					 | 
				
			||||||
        doc_gradient = model_gradient[0][0:sent_start]
 | 
					 | 
				
			||||||
        sent_gradients = list()
 | 
					 | 
				
			||||||
        desc_gradients = list()
 | 
					 | 
				
			||||||
        for x in model_gradient:
 | 
					 | 
				
			||||||
            sent_gradients.append(list(x[sent_start:desc_start]))
 | 
					 | 
				
			||||||
            desc_gradients.append(list(x[desc_start:]))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # print("doc_gradient", doc_gradient)
 | 
					                    # print("gradient", gradient)
 | 
				
			||||||
        # print("sent_gradients", sent_gradients)
 | 
					                    # print("loss", loss)
 | 
				
			||||||
        # print("desc_gradients", desc_gradients)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        bp_doc([doc_gradient], sgd=self.sgd_article)
 | 
					                    bp_sent(gradient, sgd=self.sgd_sent)
 | 
				
			||||||
        bp_sent(sent_gradients, sgd=self.sgd_sent)
 | 
					                    # bp_desc(desc_gradients, sgd=self.sgd_desc)    TODO
 | 
				
			||||||
        bp_desc(desc_gradients, sgd=self.sgd_desc)
 | 
					                    # print()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
 | 
					    def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
 | 
				
			||||||
        id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
 | 
					        id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
 | 
				
			||||||
| 
						 | 
					@ -373,13 +359,14 @@ class EL_Model:
 | 
				
			||||||
                                                                                         collect_correct=True,
 | 
					                                                                                         collect_correct=True,
 | 
				
			||||||
                                                                                         collect_incorrect=True)
 | 
					                                                                                         collect_incorrect=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        entities = set()
 | 
					        entities_by_cluster = dict()
 | 
				
			||||||
        gold_by_entity = dict()
 | 
					        gold_by_entity = dict()
 | 
				
			||||||
        desc_by_entity = dict()
 | 
					        desc_by_entity = dict()
 | 
				
			||||||
        article_by_entity = dict()
 | 
					        article_by_cluster = dict()
 | 
				
			||||||
        text_by_article = dict()
 | 
					        text_by_article = dict()
 | 
				
			||||||
        sentence_by_entity = dict()
 | 
					        sentence_by_cluster = dict()
 | 
				
			||||||
        text_by_sentence = dict()
 | 
					        text_by_sentence = dict()
 | 
				
			||||||
 | 
					        sentence_by_text = dict()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        cnt = 0
 | 
					        cnt = 0
 | 
				
			||||||
        next_entity_nr = 1
 | 
					        next_entity_nr = 1
 | 
				
			||||||
| 
						 | 
					@ -402,74 +389,69 @@ class EL_Model:
 | 
				
			||||||
                        text_by_article[article_id] = truncated_text
 | 
					                        text_by_article[article_id] = truncated_text
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    # process all positive and negative entities, collect all relevant mentions in this article
 | 
					                    # process all positive and negative entities, collect all relevant mentions in this article
 | 
				
			||||||
                    article_terms = set()
 | 
					 | 
				
			||||||
                    entities_by_mention = dict()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    for mention, entity_pos in correct_entries[article_id].items():
 | 
					                    for mention, entity_pos in correct_entries[article_id].items():
 | 
				
			||||||
 | 
					                        cluster = article_id + "_" + mention
 | 
				
			||||||
                        descr = id_to_descr.get(entity_pos)
 | 
					                        descr = id_to_descr.get(entity_pos)
 | 
				
			||||||
 | 
					                        entities = set()
 | 
				
			||||||
                        if descr:
 | 
					                        if descr:
 | 
				
			||||||
                            entity = "E_" + str(next_entity_nr) + "_" + article_id + "_" + mention
 | 
					                            entity = "E_" + str(next_entity_nr) + "_" + cluster
 | 
				
			||||||
                            next_entity_nr += 1
 | 
					                            next_entity_nr += 1
 | 
				
			||||||
                            gold_by_entity[entity] = 1
 | 
					                            gold_by_entity[entity] = 1
 | 
				
			||||||
                            desc_by_entity[entity] = descr
 | 
					                            desc_by_entity[entity] = descr
 | 
				
			||||||
                            article_terms.add(mention)
 | 
					 | 
				
			||||||
                            mention_entities = entities_by_mention.get(mention, set())
 | 
					 | 
				
			||||||
                            mention_entities.add(entity)
 | 
					 | 
				
			||||||
                            entities_by_mention[mention] = mention_entities
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    for mention, entity_negs in incorrect_entries[article_id].items():
 | 
					 | 
				
			||||||
                        for entity_neg in entity_negs:
 | 
					 | 
				
			||||||
                            descr = id_to_descr.get(entity_neg)
 | 
					 | 
				
			||||||
                            if descr:
 | 
					 | 
				
			||||||
                                entity = "E_" + str(next_entity_nr) + "_" + article_id + "_" + mention
 | 
					 | 
				
			||||||
                                next_entity_nr += 1
 | 
					 | 
				
			||||||
                                gold_by_entity[entity] = 0
 | 
					 | 
				
			||||||
                                desc_by_entity[entity] = descr
 | 
					 | 
				
			||||||
                                article_terms.add(mention)
 | 
					 | 
				
			||||||
                                mention_entities = entities_by_mention.get(mention, set())
 | 
					 | 
				
			||||||
                                mention_entities.add(entity)
 | 
					 | 
				
			||||||
                                entities_by_mention[mention] = mention_entities
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # find all matches in the doc for the mentions
 | 
					 | 
				
			||||||
                    # TODO: fix this - doesn't look like all entities are found
 | 
					 | 
				
			||||||
                    matcher = PhraseMatcher(self.nlp.vocab)
 | 
					 | 
				
			||||||
                    patterns = list(self.nlp.tokenizer.pipe(article_terms))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    matcher.add("TerminologyList", None, *patterns)
 | 
					 | 
				
			||||||
                    matches = matcher(article_doc)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    # store sentences
 | 
					 | 
				
			||||||
                    sentence_to_id = dict()
 | 
					 | 
				
			||||||
                    for match_id, start, end in matches:
 | 
					 | 
				
			||||||
                        span = article_doc[start:end]
 | 
					 | 
				
			||||||
                        sent_text = span.sent.text
 | 
					 | 
				
			||||||
                        sent_nr = sentence_to_id.get(sent_text,  None)
 | 
					 | 
				
			||||||
                        mention = span.text
 | 
					 | 
				
			||||||
                        if sent_nr is None:
 | 
					 | 
				
			||||||
                            sent_nr = "S_" + str(next_sent_nr) + article_id
 | 
					 | 
				
			||||||
                            next_sent_nr += 1
 | 
					 | 
				
			||||||
                            text_by_sentence[sent_nr] = sent_text
 | 
					 | 
				
			||||||
                            sentence_to_id[sent_text] = sent_nr
 | 
					 | 
				
			||||||
                        mention_entities = entities_by_mention[mention]
 | 
					 | 
				
			||||||
                        for entity in mention_entities:
 | 
					 | 
				
			||||||
                            entities.add(entity)
 | 
					                            entities.add(entity)
 | 
				
			||||||
                            sentence_by_entity[entity] = sent_nr
 | 
					 | 
				
			||||||
                            article_by_entity[entity] = article_id
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # remove entities that didn't have all data
 | 
					                            entity_negs = incorrect_entries[article_id][mention]
 | 
				
			||||||
        gold_by_entity = {k: v for k, v in gold_by_entity.items() if k in entities}
 | 
					                            for entity_neg in entity_negs:
 | 
				
			||||||
        desc_by_entity = {k: v for k, v in desc_by_entity.items() if k in entities}
 | 
					                                descr = id_to_descr.get(entity_neg)
 | 
				
			||||||
 | 
					                                if descr:
 | 
				
			||||||
 | 
					                                    entity = "E_" + str(next_entity_nr) + "_" + cluster
 | 
				
			||||||
 | 
					                                    next_entity_nr += 1
 | 
				
			||||||
 | 
					                                    gold_by_entity[entity] = 0
 | 
				
			||||||
 | 
					                                    desc_by_entity[entity] = descr
 | 
				
			||||||
 | 
					                                    entities.add(entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        article_by_entity = {k: v for k, v in article_by_entity.items() if k in entities}
 | 
					                        found_matches = 0
 | 
				
			||||||
        text_by_article = {k: v for k, v in text_by_article.items() if k in article_by_entity.values()}
 | 
					                        if len(entities) > 1:
 | 
				
			||||||
 | 
					                            entities_by_cluster[cluster] = entities
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                            # find all matches in the doc for the mentions
 | 
				
			||||||
 | 
					                            # TODO: fix this - doesn't look like all entities are found
 | 
				
			||||||
 | 
					                            matcher = PhraseMatcher(self.nlp.vocab)
 | 
				
			||||||
 | 
					                            patterns = list(self.nlp.tokenizer.pipe([mention]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                            matcher.add("TerminologyList", None, *patterns)
 | 
				
			||||||
 | 
					                            matches = matcher(article_doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                            # store sentences
 | 
				
			||||||
 | 
					                            for match_id, start, end in matches:
 | 
				
			||||||
 | 
					                                found_matches += 1
 | 
				
			||||||
 | 
					                                span = article_doc[start:end]
 | 
				
			||||||
 | 
					                                assert mention == span.text
 | 
				
			||||||
 | 
					                                sent_text = span.sent.text
 | 
				
			||||||
 | 
					                                sent_nr = sentence_by_text.get(sent_text,  None)
 | 
				
			||||||
 | 
					                                if sent_nr is None:
 | 
				
			||||||
 | 
					                                    sent_nr = "S_" + str(next_sent_nr) + article_id
 | 
				
			||||||
 | 
					                                    next_sent_nr += 1
 | 
				
			||||||
 | 
					                                    text_by_sentence[sent_nr] = sent_text
 | 
				
			||||||
 | 
					                                    sentence_by_text[sent_text] = sent_nr
 | 
				
			||||||
 | 
					                                article_by_cluster[cluster] = article_id
 | 
				
			||||||
 | 
					                                sentence_by_cluster[cluster] = sent_nr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        if found_matches == 0:
 | 
				
			||||||
 | 
					                            # TODO print("Could not find neg instances or sentence matches for", mention, "in", article_id)
 | 
				
			||||||
 | 
					                            entities_by_cluster.pop(cluster, None)
 | 
				
			||||||
 | 
					                            article_by_cluster.pop(cluster, None)
 | 
				
			||||||
 | 
					                            sentence_by_cluster.pop(cluster, None)
 | 
				
			||||||
 | 
					                            for entity in entities:
 | 
				
			||||||
 | 
					                                gold_by_entity.pop(entity, None)
 | 
				
			||||||
 | 
					                                desc_by_entity.pop(entity, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sentence_by_entity = {k: v for k, v in sentence_by_entity.items() if k in entities}
 | 
					 | 
				
			||||||
        text_by_sentence = {k: v for k, v in text_by_sentence.items() if k in sentence_by_entity.values()}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if to_print:
 | 
					        if to_print:
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
            print("Processed", cnt, "training articles, dev=" + str(dev))
 | 
					            print("Processed", cnt, "training articles, dev=" + str(dev))
 | 
				
			||||||
            print()
 | 
					            print()
 | 
				
			||||||
        return list(entities), gold_by_entity, desc_by_entity, article_by_entity, text_by_article, \
 | 
					        return entities_by_cluster, gold_by_entity, desc_by_entity, article_by_cluster, text_by_article, \
 | 
				
			||||||
               sentence_by_entity, text_by_sentence
 | 
					               sentence_by_cluster, text_by_sentence
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -111,7 +111,7 @@ if __name__ == "__main__":
 | 
				
			||||||
        print("STEP 6: training", datetime.datetime.now())
 | 
					        print("STEP 6: training", datetime.datetime.now())
 | 
				
			||||||
        my_nlp = spacy.load('en_core_web_md')
 | 
					        my_nlp = spacy.load('en_core_web_md')
 | 
				
			||||||
        trainer = EL_Model(kb=my_kb, nlp=my_nlp)
 | 
					        trainer = EL_Model(kb=my_kb, nlp=my_nlp)
 | 
				
			||||||
        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=100, devlimit=20)
 | 
					        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=100)
 | 
				
			||||||
        print()
 | 
					        print()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # STEP 7: apply the EL algorithm on the dev dataset
 | 
					    # STEP 7: apply the EL algorithm on the dev dataset
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user