mirror of https://github.com/explosion/spaCy.git
speeding up training

parent 66813a1fdc
commit 6521cfa132

@@ -115,6 +115,7 @@ def run_pipeline():
    # STEP 6: create the entity linking pipe
    if train_pipe:
        print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
        train_limit = 100
        dev_limit = 20
        print("Training on", train_limit, "articles")

@@ -131,7 +132,7 @@ def run_pipeline():
                                                        training_dir=TRAINING_DIR,
                                                        dev=True,
                                                        limit=dev_limit,
-                                                         to_print=False)
+                                                       to_print=False)
 
         el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF})
         nlp.add_pipe(el_pipe, last=True)
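
The last two context lines of this hunk are where the script registers the new entity_linker component. A minimal, self-contained sketch of the same create_pipe/add_pipe pattern, using spaCy v2's built-in sentencizer factory because the entity_linker factory shown in the diff is still work-in-progress code on this branch:

import spacy

# Sketch only (spaCy v2 API): build a pipe from a factory name and append it
# to the pipeline; last=True places it after all existing components, which is
# what the script does with the entity linker pipe.
nlp = spacy.blank("en")
sentencizer = nlp.create_pipe("sentencizer")
nlp.add_pipe(sentencizer, last=True)
print(nlp.pipe_names)  # ['sentencizer']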

@@ -147,35 +148,40 @@ def run_pipeline():
 
             with nlp.disable_pipes(*other_pipes):
                 for batch in batches:
-                    docs, golds = zip(*batch)
-                    nlp.update(
-                        docs,
-                        golds,
-                        drop=DROPOUT,
-                        losses=losses,
-                    )
+                    try:
+                        docs, golds = zip(*batch)
+                        nlp.update(
+                            docs,
+                            golds,
+                            drop=DROPOUT,
+                            losses=losses,
+                        )
+                    except Exception as e:
+                        print("Error updating batch", e)
 
-            # print(" measuring accuracy 1-1")
-            el_pipe.context_weight = 1
-            el_pipe.prior_weight = 1
-            dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
-            train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
+            print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
 
-            # print(" measuring accuracy 0-1")
-            el_pipe.context_weight = 0
-            el_pipe.prior_weight = 1
-            dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
-            train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
+        # baseline using only prior probabilities
+        el_pipe.context_weight = 0
+        el_pipe.prior_weight = 1
+        dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
+        train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
 
-            # print(" measuring accuracy 1-0")
-            el_pipe.context_weight = 1
-            el_pipe.prior_weight = 0
-            dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
-            train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
+        # print(" measuring accuracy 1-1")
+        el_pipe.context_weight = 1
+        el_pipe.prior_weight = 1
+        dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
+        train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
 
-            print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, round(losses['entity_linker'], 2),
-                  round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/",
-                  round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2))
+        # print(" measuring accuracy 1-0")
+        el_pipe.context_weight = 1
+        el_pipe.prior_weight = 0
+        dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
+        train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
+
+        print("train/dev acc, 1-1, 0-1, 1-0:" ,
+              round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/",
+              round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2))
 
     # test Entity Linker
     if to_test_pipeline:
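
This hunk is the core of the speed-up: nlp.update is wrapped in a try/except so one bad batch no longer aborts an epoch, and the repeated accuracy passes over train and dev data are moved out of the epoch loop, leaving only a per-epoch loss print. A minimal sketch of the resulting loop shape, assuming spaCy v2's minibatch/compounding helpers and a train_data list of (doc, gold) pairs; the train_entity_linker name and the n_iter/dropout arguments are illustrative, not from the commit:

import random
from spacy.util import minibatch, compounding

def train_entity_linker(nlp, train_data, n_iter=10, dropout=0.1):
    # only update the entity_linker weights; freeze the other components
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
    with nlp.disable_pipes(*other_pipes):
        for itn in range(n_iter):
            random.shuffle(train_data)
            losses = {}
            batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
            for batch in batches:
                try:
                    docs, golds = zip(*batch)
                    # spaCy v2 signature: update(docs, golds, drop=..., losses=...)
                    nlp.update(docs, golds, drop=dropout, losses=losses)
                except Exception as e:
                    # skip a malformed batch instead of aborting the epoch
                    print("Error updating batch", e)
            # per epoch, only the loss is reported; the expensive accuracy
            # measurements now run once, after training
            print("Epoch, train loss", itn, round(losses.get("entity_linker", 0.0), 2))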

@@ -193,26 +199,29 @@ def _measure_accuracy(data, el_pipe):
 
     docs = [d for d, g in data]
     docs = el_pipe.pipe(docs)
 
     golds = [g for d, g in data]
 
     for doc, gold in zip(docs, golds):
-        correct_entries_per_article = dict()
-        for entity in gold.links:
-            start, end, gold_kb = entity
-            correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+        try:
+            correct_entries_per_article = dict()
+            for entity in gold.links:
+                start, end, gold_kb = entity
+                correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
 
-        for ent in doc.ents:
-            if ent.label_ == "PERSON":  # TODO: expand to other types
-                pred_entity = ent.kb_id_
-                start = ent.start
-                end = ent.end
-                gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
-                if gold_entity is not None:
-                    if gold_entity == pred_entity:
-                        correct += 1
-                    else:
-                        incorrect += 1
+            for ent in doc.ents:
+                if ent.label_ == "PERSON":  # TODO: expand to other types
+                    pred_entity = ent.kb_id_
+                    start = ent.start
+                    end = ent.end
+                    gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+                    if gold_entity is not None:
+                        if gold_entity == pred_entity:
+                            correct += 1
+                        else:
+                            incorrect += 1
 
+        except Exception as e:
+            print("Error assessing accuracy", e)
 
     if correct == incorrect == 0:
         return 0
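
In _measure_accuracy the whole per-document evaluation is wrapped in a try/except, so a single document with problematic gold annotations no longer stops the measurement. The counting logic itself is unchanged; a compact sketch of it as a standalone helper, assuming gold_links is an iterable of (start, end, kb_id) tuples keyed on token offsets (the count_correct name is illustrative, and it leaves out the script's extra restriction to PERSON entities):

def count_correct(doc, gold_links):
    # Map "start-end" token offsets to the gold KB identifier, then compare the
    # predicted ent.kb_id_ of each entity span against it.
    correct = incorrect = 0
    gold_by_offset = {"{}-{}".format(start, end): kb_id
                      for start, end, kb_id in gold_links}
    for ent in doc.ents:
        gold_kb = gold_by_offset.get("{}-{}".format(ent.start, ent.end))
        if gold_kb is not None:
            if gold_kb == ent.kb_id_:
                correct += 1
            else:
                incorrect += 1
    return correct, incorrect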

@@ -243,4 +252,4 @@ def run_el_toy_example(nlp, kb):
 
 
 if __name__ == "__main__":
-    run_pipeline()
+    run_pipeline()

@@ -1220,8 +1220,13 @@ class EntityLinker(Pipe):

    def predict(self, docs):
        self.require_model()

        if isinstance(docs, Doc):
            docs = [docs]

        final_entities = list()
        final_kb_ids = list()

        for i, article_doc in enumerate(docs):
            doc_encoding = self.article_encoder([article_doc])
            for ent in article_doc.ents:
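
The change to EntityLinker.predict lets the pipe accept either a single Doc or an iterable of Doc objects, matching the convention of other spaCy pipes. The same normalization as a tiny standalone helper (the as_doc_list name is illustrative):

from spacy.tokens import Doc

def as_doc_list(docs):
    # predict()-style input normalization: wrap a lone Doc so downstream code
    # can always iterate over a sequence of Docs.
    if isinstance(docs, Doc):
        return [docs]
    return list(docs)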