	obtain sentence for each mention
commit 4392c01b7b
parent 97241a3ed7
@@ -70,7 +70,7 @@ def is_dev(file_name):
     return file_name.endswith("3.txt")
 
 
-def evaluate(predictions, golds, to_print=True):
+def evaluate(predictions, golds, to_print=True, times_hundred=True):
     if len(predictions) != len(golds):
         raise ValueError("predictions and gold entities should have the same length")
 
@@ -101,8 +101,11 @@ def evaluate(predictions, golds, to_print=True):
         print("fp", fp)
         print("fn", fn)
 
-    precision = 100 * tp / (tp + fp + 0.0000001)
-    recall = 100 * tp / (tp + fn + 0.0000001)
+    precision = tp / (tp + fp + 0.0000001)
+    recall = tp / (tp + fn + 0.0000001)
+    if times_hundred:
+        precision = precision*100
+        recall = recall*100
     fscore = 2 * recall * precision / (recall + precision + 0.0000001)
 
     accuracy = corrects / (corrects + incorrects)
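For reference, with times_hundred=False the hunk above reduces to the standard epsilon-smoothed metrics. A minimal standalone sketch (not the repo's code; the tp/fp/fn counts are hypothetical):

    def prf(tp, fp, fn, eps=0.0000001):
        # eps keeps the divisions defined when a class is empty
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        fscore = 2 * recall * precision / (recall + precision + eps)
        return precision, recall, fscore

    # example: tp=8, fp=2, fn=4 -> precision 0.8, recall ~0.667, F ~0.727
    print(prf(8, 2, 4))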
@@ -20,6 +20,7 @@ from thinc.t2t import ParametricAttention
 from thinc.misc import Residual
 from thinc.misc import LayerNorm as LN
 
+from spacy.matcher import PhraseMatcher
 from spacy.tokens import Doc
 
 """ TODO: this code needs to be implemented in pipes.pyx"""
@@ -27,13 +28,16 @@ from spacy.tokens import Doc
 
 class EL_Model:
 
+    PRINT_INSPECT = False
     PRINT_TRAIN = False
     EPS = 0.0000000005
     CUTOFF = 0.5
 
     BATCH_SIZE = 5
 
-    INPUT_DIM = 300
+    DOC_CUTOFF = 300    # number of characters from the doc context
+    INPUT_DIM = 300     # dimension of pre-trained vectors
+
     HIDDEN_1_WIDTH = 32   # 10
     HIDDEN_2_WIDTH = 32  # 6
     DESC_WIDTH = 64     # 4
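One note on the new DOC_CUTOFF constant: it is applied later in this commit as text[0:min(self.DOC_CUTOFF, len(text))]. Python slicing already clamps the end index, so the min() guard is redundant, though harmless. A quick check:

    text = "short"
    DOC_CUTOFF = 300
    # slicing past the end of a string is safe in Python
    assert text[0:min(DOC_CUTOFF, len(text))] == text[:DOC_CUTOFF] == "short"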
@@ -58,11 +62,20 @@ class EL_Model:
         # raise errors instead of runtime warnings in case of int/float overflow
         np.seterr(all='raise')
 
-        train_ent, train_gold, train_desc, train_article, train_texts = self._get_training_data(training_dir,
-                                                                                                entity_descr_output,
-                                                                                                False,
-                                                                                                trainlimit,
-                                                                                                to_print=False)
+        train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
+            self._get_training_data(training_dir, entity_descr_output, False, trainlimit, to_print=False)
+
+        # inspect data
+        if self.PRINT_INSPECT:
+            for entity in train_ent:
+                print("entity", entity)
+                print("gold", train_gold[entity])
+                print("desc", train_desc[entity])
+                print("sentence ID", train_sent[entity])
+                print("sentence text", train_sent_texts[train_sent[entity]])
+                print("article ID", train_art[entity])
+                print("article text", train_art_texts[train_art[entity]])
+                print()
 
         train_pos_entities = [k for k,v in train_gold.items() if v]
         train_neg_entities = [k for k,v in train_gold.items() if not v]
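_get_training_data now returns seven values instead of five. The shapes implied by the hunks in this commit, with illustrative keys (the E_<nr>_<article>_<mention> and S_<nr><article> formats come from the hunk further below):

    # all keys below are made up, following the formats built in this commit
    train_ent = ["E_1_a1_Paris"]                  # entity keys
    train_gold = {"E_1_a1_Paris": 1}              # 1 = correct candidate, 0 = incorrect
    train_desc = {"E_1_a1_Paris": "capital of France"}
    train_art = {"E_1_a1_Paris": "a1"}            # entity -> article ID
    train_art_texts = {"a1": "Paris is ..."}      # article ID -> truncated article text
    train_sent = {"E_1_a1_Paris": "S_1a1"}        # entity -> sentence ID
    train_sent_texts = {"S_1a1": "Paris is ..."}  # sentence ID -> sentence text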
@@ -70,6 +83,10 @@ class EL_Model:
         train_pos_count = len(train_pos_entities)
         train_neg_count = len(train_neg_entities)
 
+        if to_print:
+            print()
+            print("Upsampling, original training instances pos/neg:", train_pos_count, train_neg_count)
+
         # upsample positives to 50-50 distribution
         while train_pos_count < train_neg_count:
             train_ent.append(random.choice(train_pos_entities))
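The upsampling loop above duplicates random positive examples until the class counts balance. As a standalone sketch with hypothetical data (the counter update itself sits outside this hunk):

    import random

    train_pos_entities = ["p1", "p2"]
    train_ent = ["p1", "p2", "n1", "n2", "n3", "n4"]
    train_pos_count, train_neg_count = 2, 4

    # duplicate random positives until the 50-50 distribution is reached
    while train_pos_count < train_neg_count:
        train_ent.append(random.choice(train_pos_entities))
        train_pos_count += 1

    print(train_ent)  # the 6 originals plus 2 duplicated positives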
@@ -82,11 +99,8 @@ class EL_Model:
 
         shuffle(train_ent)
 
-        dev_ent, dev_gold, dev_desc, dev_article, dev_texts = self._get_training_data(training_dir,
-                                                                                      entity_descr_output,
-                                                                                      True,
-                                                                                      devlimit,
-                                                                                      to_print=False)
+        dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts = \
+            self._get_training_data(training_dir, entity_descr_output, True, devlimit, to_print=False)
         shuffle(dev_ent)
 
         dev_pos_count = len([g for g in dev_gold.values() if g])
@@ -94,20 +108,16 @@ class EL_Model:
 
         self._begin_training()
 
-        print()
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_random", calc_random=True)
-        print()
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_pre", avg=True)
-
         if to_print:
             print()
-            print("Training on", len(train_ent), "entities in", len(train_texts), "articles")
-            print("Training instances pos/neg", train_pos_count, train_neg_count)
+            print("Training on", len(train_ent), "entities in", len(train_art_texts), "articles")
+            print("Training instances pos/neg:", train_pos_count, train_neg_count)
             print()
-            print("Dev test on", len(dev_ent), "entities in", len(dev_texts), "articles")
-            print("Dev instances pos/neg", dev_pos_count, dev_neg_count)
+            print("Dev test on", len(dev_ent), "entities in", len(dev_art_texts), "articles")
+            print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
             print()
             print(" CUTOFF", self.CUTOFF)
+            print(" DOC_CUTOFF", self.DOC_CUTOFF)
             print(" INPUT_DIM", self.INPUT_DIM)
             print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
             print(" DESC_WIDTH", self.DESC_WIDTH)
@@ -116,6 +126,10 @@ class EL_Model:
             print(" DROP", self.DROP)
             print()
 
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_random", calc_random=True)
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_pre", avg=True)
+        print()
+
         start = 0
         stop = min(self.BATCH_SIZE, len(train_ent))
         processed = 0
@@ -125,10 +139,10 @@ class EL_Model:
 
             golds = [train_gold[e] for e in next_batch]
             descs = [train_desc[e] for e in next_batch]
-            articles = [train_texts[train_article[e]] for e in next_batch]
+            articles = [train_art_texts[train_art[e]] for e in next_batch]
 
             self.update(entities=next_batch, golds=golds, descs=descs, texts=articles)
-            self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_inter", avg=True)
+            self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_inter", avg=True)
 
             processed += len(next_batch)
 
@@ -151,7 +165,7 @@ class EL_Model:
             predictions = self._predict(entities=entities, article_docs=article_docs, desc_docs=desc_docs, avg=avg)
 
         # TODO: combine with prior probability
-        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False)
+        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False, times_hundred=False)
         loss, gradient = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds))
 
         print("p/r/F/acc/loss", print_string, round(p, 1), round(r, 1), round(f, 1), round(acc, 2), round(loss, 5))
@@ -288,14 +302,18 @@ class EL_Model:
                                                                                          collect_incorrect=True)
 
         local_vectors = list()   # TODO: local vectors
-        text_by_article = dict()
+
+        entities = set()
         gold_by_entity = dict()
         desc_by_entity = dict()
         article_by_entity = dict()
-        entities = list()
+        text_by_article = dict()
+        sentence_by_entity = dict()
+        text_by_sentence = dict()
 
         cnt = 0
-        next_entity_nr = 0
+        next_entity_nr = 1
+        next_sent_nr = 1
         files = listdir(training_dir)
         shuffle(files)
         for f in files:
@@ -305,33 +323,81 @@ class EL_Model:
                     if cnt % 500 == 0 and to_print:
                         print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
                     cnt += 1
-                    if article_id not in text_by_article:
-                        with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
-                            text = file.read()
-                            text_by_article[article_id] = text
+
+                    # parse the article text
+                    with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
+                        text = file.read()
+                        article_doc = self.nlp(text)
+                        truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
+                        text_by_article[article_id] = truncated_text
+
+                    # process all positive and negative entities, collect all relevant mentions in this article
+                    article_terms = set()
+                    entities_by_mention = dict()
 
                     for mention, entity_pos in correct_entries[article_id].items():
                         descr = id_to_descr.get(entity_pos)
                         if descr:
-                            entities.append(next_entity_nr)
-                            gold_by_entity[next_entity_nr] = 1
-                            desc_by_entity[next_entity_nr] = descr
-                            article_by_entity[next_entity_nr] = article_id
+                            entity = "E_" + str(next_entity_nr) + "_" + article_id + "_" + mention
                             next_entity_nr += 1
+                            gold_by_entity[entity] = 1
+                            desc_by_entity[entity] = descr
+                            article_terms.add(mention)
+                            mention_entities = entities_by_mention.get(mention, set())
+                            mention_entities.add(entity)
+                            entities_by_mention[mention] = mention_entities
 
                     for mention, entity_negs in incorrect_entries[article_id].items():
                         for entity_neg in entity_negs:
                             descr = id_to_descr.get(entity_neg)
                             if descr:
-                                entities.append(next_entity_nr)
-                                gold_by_entity[next_entity_nr] = 0
-                                desc_by_entity[next_entity_nr] = descr
-                                article_by_entity[next_entity_nr] = article_id
+                                entity = "E_" + str(next_entity_nr) + "_" + article_id + "_" + mention
                                 next_entity_nr += 1
+                                gold_by_entity[entity] = 0
+                                desc_by_entity[entity] = descr
+                                article_terms.add(mention)
+                                mention_entities = entities_by_mention.get(mention, set())
+                                mention_entities.add(entity)
+                                entities_by_mention[mention] = mention_entities
+
+                    # find all matches in the doc for the mentions
+                    # TODO: fix this - doesn't look like all entities are found
+                    matcher = PhraseMatcher(self.nlp.vocab)
+                    patterns = list(self.nlp.tokenizer.pipe(article_terms))
+
+                    matcher.add("TerminologyList", None, *patterns)
+                    matches = matcher(article_doc)
+
+                    # store sentences
+                    sentence_to_id = dict()
+                    for match_id, start, end in matches:
+                        span = article_doc[start:end]
+                        sent_text = span.sent
+                        sent_nr = sentence_to_id.get(sent_text,  None)
+                        if sent_nr is None:
+                            sent_nr = "S_" + str(next_sent_nr) + article_id
+                            next_sent_nr += 1
+                            text_by_sentence[sent_nr] = sent_text
+                            sentence_to_id[sent_text] = sent_nr
+                        mention_entities = entities_by_mention[span.text]
+                        for entity in mention_entities:
+                            entities.add(entity)
+                            sentence_by_entity[entity] = sent_nr
+                            article_by_entity[entity] = article_id
+
+        # remove entities that didn't have all data
+        gold_by_entity = {k: v for k, v in gold_by_entity.items() if k in entities}
+        desc_by_entity = {k: v for k, v in desc_by_entity.items() if k in entities}
+
+        article_by_entity = {k: v for k, v in article_by_entity.items() if k in entities}
+        text_by_article = {k: v for k, v in text_by_article.items() if k in article_by_entity.values()}
+
+        sentence_by_entity = {k: v for k, v in sentence_by_entity.items() if k in entities}
+        text_by_sentence = {k: v for k, v in text_by_sentence.items() if k in sentence_by_entity.values()}
 
         if to_print:
             print()
             print("Processed", cnt, "training articles, dev=" + str(dev))
             print()
-        return entities, gold_by_entity, desc_by_entity, article_by_entity, text_by_article
+        return list(entities), gold_by_entity, desc_by_entity, article_by_entity, text_by_article, sentence_by_entity, text_by_sentence
 
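The core of this commit is the sentence lookup above: mentions are matched back into the article Doc with spaCy's PhraseMatcher, and span.sent yields the containing sentence. A self-contained sketch of that pattern under the spaCy v2 API used in the hunk (the matcher.add(name, None, *patterns) signature); texts and mentions are made up:

    import spacy
    from spacy.matcher import PhraseMatcher

    nlp = spacy.load("en_core_web_md")
    doc = nlp("Douglas Adams wrote the Guide. The Guide is a book.")

    matcher = PhraseMatcher(nlp.vocab)
    # patterns are tokenized mention strings, as in the hunk above
    patterns = list(nlp.tokenizer.pipe(["Douglas Adams", "Guide"]))
    matcher.add("TerminologyList", None, *patterns)

    for match_id, start, end in matcher(doc):
        span = doc[start:end]
        # span.sent is the sentence Span containing the match
        print(span.text, "->", span.sent.text)

Note that span.sent returns a Span object rather than a string, so the sentence texts collected in the hunk above are Spans.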
@@ -111,7 +111,7 @@ if __name__ == "__main__":
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=400, devlimit=50)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=100, devlimit=20)
         print()
 
     # STEP 7: apply the EL algorithm on the dev dataset