mirror of https://github.com/explosion/spaCy.git

commit 81731907ba (parent b312f2d0e7)

    performance per entity type
kb_creator.py (file names are inferred from module references in this commit; this one from the kb_creator._get_entity_to_id call below):

@@ -15,10 +15,10 @@ INPUT_DIM = 300  # dimension of pre-trained vectors
 DESC_WIDTH = 64
 
 
-def create_kb(nlp, max_entities_per_alias, min_occ,
+def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
               entity_def_output, entity_descr_output,
               count_input, prior_prob_input, to_print=False):
-    """ Create the knowledge base from Wikidata entries """
+    # Create the knowledge base from Wikidata entries
     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
 
     # disable this part of the pipeline when rerunning the KB generation from preprocessed files
@@ -37,21 +37,26 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
         title_to_id = _get_entity_to_id(entity_def_output)
         id_to_descr = _get_id_to_description(entity_descr_output)
 
-    title_list = list(title_to_id.keys())
-
-    # TODO: remove this filter (just for quicker testing of code)
-    # title_list = title_list[0:342]
-    # title_to_id = {t: title_to_id[t] for t in title_list}
-
-    entity_list = [title_to_id[x] for x in title_list]
-
-    # Currently keeping entities from the KB where there is no description - putting a default void description
-    description_list = [id_to_descr.get(x, "No description defined") for x in entity_list]
-
     print()
     print(" * _get_entity_frequencies", datetime.datetime.now())
     print()
-    entity_frequencies = wp.get_entity_frequencies(count_input=count_input, entities=title_list)
+    entity_frequencies = wp.get_all_frequencies(count_input=count_input)
 
+    # filter the entities in the KB by frequency, because there's just too much data otherwise
+    filtered_title_to_id = dict()
+    entity_list = list()
+    description_list = list()
+    frequency_list = list()
+    for title, entity in title_to_id.items():
+        freq = entity_frequencies.get(title, 0)
+        desc = id_to_descr.get(entity, None)
+        if desc and freq > min_entity_freq:
+            entity_list.append(entity)
+            description_list.append(desc)
+            frequency_list.append(freq)
+            filtered_title_to_id[title] = entity
+
+    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles")
+
     print()
     print(" * train entity encoder", datetime.datetime.now())
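The block added in this hunk replaces the old unconditional list comprehensions: an entity now enters the KB only if it has a description and its frequency clears min_entity_freq, and four parallel structures are built in one pass. A runnable toy sketch of that filter (all titles, IDs and counts below are invented for illustration, not taken from the commit):

    MIN_ENTITY_FREQ = 2
    title_to_id = {"Douglas Adams": "Q42", "Obscure Page": "Q999", "Brazil": "Q155"}
    id_to_descr = {"Q42": "English writer", "Q155": "country in South America"}
    entity_frequencies = {"Douglas Adams": 5, "Obscure Page": 1, "Brazil": 7}

    filtered_title_to_id = {}
    entity_list, description_list, frequency_list = [], [], []
    for title, entity in title_to_id.items():
        freq = entity_frequencies.get(title, 0)
        desc = id_to_descr.get(entity)
        if desc and freq > MIN_ENTITY_FREQ:  # keep only described, frequent entities
            entity_list.append(entity)
            description_list.append(desc)
            frequency_list.append(freq)
            filtered_title_to_id[title] = entity

    print(entity_list, frequency_list)  # ['Q42', 'Q155'] [5, 7]
    # "Obscure Page" is dropped (freq 1 <= 2); an entity without a
    # description would be dropped even if it were frequent.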
@@ -67,12 +72,12 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
 
     print()
     print(" * adding", len(entity_list), "entities", datetime.datetime.now())
-    kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=embeddings)
+    kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
 
     print()
     print(" * adding aliases", datetime.datetime.now())
     print()
-    _add_aliases(kb, title_to_id=title_to_id,
+    _add_aliases(kb, title_to_id=filtered_title_to_id,
                  max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
                  prior_prob_input=prior_prob_input)
 
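Both call sites in the hunk above change accordingly: kb.set_entities must receive the frequency list that is index-aligned with the filtered entity_list (the old entity_frequencies was computed over all titles, so it would no longer line up), and _add_aliases must only see the filtered title map so that aliases never point at entities that were dropped. A small sketch of the alignment invariant (toy values, not spaCy internals):

    entity_list = ["Q42", "Q155"]
    frequency_list = [5, 7]
    embeddings = [[0.1] * 4, [0.2] * 4]  # toy vectors; the real ones are DESC_WIDTH-dimensional

    # entity_list[i], frequency_list[i] and embeddings[i] must describe the
    # same entity, which holds by construction of the filter loop above.
    assert len(entity_list) == len(frequency_list) == len(embeddings)
    for entity, freq, vec in zip(entity_list, frequency_list, embeddings):
        print(entity, freq, vec)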
training_set_creator.py (inferred from the training_set_creator.create_training call elsewhere in this commit):

@@ -21,7 +21,7 @@ ENTITY_FILE = "gold_entities.csv"
 
 def create_training(entity_def_input, training_output):
     wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
-    _process_wikipedia_texts(wp_to_id, training_output, limit=100000000)  # TODO: full dataset   100000000
+    _process_wikipedia_texts(wp_to_id, training_output, limit=100000000)
 
 
 def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
Pipeline driver script (file name dropped by this view; it is the module that defines run_pipeline):

@@ -29,6 +29,7 @@ NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
 TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
 
 MAX_CANDIDATES = 10
+MIN_ENTITY_FREQ = 200
 MIN_PAIR_OCC = 5
 DOC_SENT_CUTOFF = 2
 EPOCHS = 10
@@ -46,14 +47,14 @@ def run_pipeline():
     # one-time methods to create KB and write to file
     to_create_prior_probs = False
     to_create_entity_counts = False
-    to_create_kb = False  # TODO: entity_defs should also contain entities not in the KB
+    to_create_kb = True
 
     # read KB back in from file
     to_read_kb = False
     to_test_kb = False
 
     # create training dataset
-    create_wp_training = True
+    create_wp_training = False
 
     # train the EL pipe
     train_pipe = False
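The step toggles are flipped for this commit's run: the KB is regenerated under the new frequency threshold (to_create_kb = True), while the expensive Wikipedia training-data creation is switched off (create_wp_training = False).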
@@ -85,6 +86,7 @@ def run_pipeline():
         print("STEP 3a: to_create_kb", datetime.datetime.now())
         kb_1 = kb_creator.create_kb(nlp_1,
                                     max_entities_per_alias=MAX_CANDIDATES,
+                                    min_entity_freq=MIN_ENTITY_FREQ,
                                     min_occ=MIN_PAIR_OCC,
                                     entity_def_output=ENTITY_DEFS,
                                     entity_descr_output=ENTITY_DESCR,
@@ -112,7 +114,7 @@ def run_pipeline():
 
         # test KB
         if to_test_kb:
-            run_el.run_kb_toy_example(kb=kb_2)
+            test_kb(kb_2)
             print()
 
     # STEP 5: create a training dataset from WP
@@ -121,10 +123,18 @@ def run_pipeline():
         training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
 
     # STEP 6: create the entity linking pipe
+    el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
+    el_pipe.set_kb(kb_2)
+    nlp_2.add_pipe(el_pipe, last=True)
+
+    other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
+    with nlp_2.disable_pipes(*other_pipes):  # only train Entity Linking
+        nlp_2.begin_training()
+
     if train_pipe:
         print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        train_limit = 50
-        dev_limit = 10
+        train_limit = 10
+        dev_limit = 2
         print("Training on", train_limit, "articles")
         print("Dev testing on", dev_limit, "articles")
         print()
@@ -141,14 +151,6 @@ def run_pipeline():
                                                       limit=dev_limit,
                                                       to_print=False)
 
-        el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
-        el_pipe.set_kb(kb_2)
-        nlp_2.add_pipe(el_pipe, last=True)
-
-        other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
-        with nlp_2.disable_pipes(*other_pipes):  # only train Entity Linking
-            nlp_2.begin_training()
-
         for itn in range(EPOCHS):
             random.shuffle(train_data)
             losses = {}
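The block removed here is the same pipe setup added unconditionally in the +123,18 hunk above: creating the entity_linker pipe, attaching the KB and calling begin_training were hoisted out of the training branch, so el_pipe is available to the evaluation and toy-example steps even when train_pipe is False.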
@@ -180,30 +182,32 @@ def run_pipeline():
             # print(" measuring accuracy 1-1")
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 1
-            dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
-            train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
-            print("train/dev acc combo:", round(train_acc_1_1, 2), round(dev_acc_1_1, 2))
+            dev_acc_1_1, dev_acc_1_1_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc combo:", round(dev_acc_1_1, 3), [(x, round(y, 3)) for x, y in dev_acc_1_1_dict.items()])
+            train_acc_1_1, train_acc_1_1_dict = _measure_accuracy(train_data, el_pipe)
+            print("train acc combo:", round(train_acc_1_1, 3), [(x, round(y, 3)) for x, y in train_acc_1_1_dict.items()])
 
             # baseline using only prior probabilities
             el_pipe.context_weight = 0
             el_pipe.prior_weight = 1
-            dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
-            train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
-            print("train/dev acc prior:", round(train_acc_0_1, 2), round(dev_acc_0_1, 2))
+            dev_acc_0_1, dev_acc_0_1_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc prior:", round(dev_acc_0_1, 3), [(x, round(y, 3)) for x, y in dev_acc_0_1_dict.items()])
+            train_acc_0_1, train_acc_0_1_dict = _measure_accuracy(train_data, el_pipe)
+            print("train acc prior:", round(train_acc_0_1, 3), [(x, round(y, 3)) for x, y in train_acc_0_1_dict.items()])
 
             # using only context
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 0
-            dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
-            train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
-            print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2))
+            dev_acc_1_0, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc context:", round(dev_acc_1_0, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
+            train_acc_1_0, train_acc_1_0_dict = _measure_accuracy(train_data, el_pipe)
+            print("train acc context:", round(train_acc_1_0, 3), [(x, round(y, 3)) for x, y in train_acc_1_0_dict.items()])
             print()
 
             # reset for follow-up tests
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 1
 
-
     if to_test_pipeline:
         print()
         print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
@@ -230,8 +234,8 @@ def run_pipeline():
 
 
 def _measure_accuracy(data, el_pipe):
-    correct = 0
-    incorrect = 0
+    correct_by_label = dict()
+    incorrect_by_label = dict()
 
     docs = [d for d, g in data if len(d) > 0]
     docs = el_pipe.pipe(docs)
@@ -245,7 +249,7 @@ def _measure_accuracy(data, el_pipe):
                 correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
 
             for ent in doc.ents:
-                if ent.label_ == "PERSON":  # TODO: expand to other types
+                ent_label = ent.label_
                 pred_entity = ent.kb_id_
                 start = ent.start_char
                 end = ent.end_char
@@ -253,23 +257,45 @@ def _measure_accuracy(data, el_pipe):
                 # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                 if gold_entity is not None:
                     if gold_entity == pred_entity:
-                            correct += 1
+                        correct = correct_by_label.get(ent_label, 0)
+                        correct_by_label[ent_label] = correct + 1
                     else:
-                            incorrect += 1
+                        incorrect = incorrect_by_label.get(ent_label, 0)
+                        incorrect_by_label[ent_label] = incorrect + 1
 
     except Exception as e:
         print("Error assessing accuracy", e)
 
+    acc_by_label = dict()
+    total_correct = 0
+    total_incorrect = 0
+    for label, correct in correct_by_label.items():
+        incorrect = incorrect_by_label.get(label, 0)
+        total_correct += correct
+        total_incorrect += incorrect
         if correct == incorrect == 0:
-        return 0
+            acc_by_label[label] = 0
+        else:
+            acc_by_label[label] = correct / (correct + incorrect)
+    acc = 0
+    if not (total_correct == total_incorrect == 0):
+        acc = total_correct / (total_correct + total_incorrect)
+    return acc, acc_by_label
 
-    acc = correct / (correct + incorrect)
-    return acc
+
+def test_kb(kb):
+    for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
+        candidates = kb.get_candidates(mention)
+
+        print("generating candidates for " + mention + " :")
+        for c in candidates:
+            print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
+        print()
 
 
 def run_el_toy_example(nlp):
     text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
-           "Douglas reminds us to always bring our towel. " \
+           "Douglas reminds us to always bring our towel, even in China or Brazil. " \
            "The main character in Doug's novel is the man Arthur Dent, " \
            "but Douglas doesn't write about George Washington or Homer Simpson."
     doc = nlp(text)
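Two changes above implement the commit title: _measure_accuracy no longer restricts evaluation to PERSON mentions (the ent_label = ent.label_ line keeps every entity type), and it now returns a pair of micro-averaged accuracy plus a per-label dict, which the new print statements in run_pipeline unpack. A standalone sketch of the aggregation, with made-up counts:

    # Per-label accuracy plus a micro-average, mirroring the new bookkeeping.
    # Counts are invented for illustration and are nonzero, so the
    # zero-division guards of the committed version are omitted here.
    correct_by_label = {"PERSON": 8, "GPE": 3}
    incorrect_by_label = {"PERSON": 2, "GPE": 1, "ORG": 4}

    acc_by_label = {}
    total_correct = total_incorrect = 0
    for label, correct in correct_by_label.items():
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect
        acc_by_label[label] = correct / (correct + incorrect)

    acc = total_correct / (total_correct + total_incorrect)
    print(round(acc, 3), acc_by_label)  # 0.786 {'PERSON': 0.8, 'GPE': 0.75}
    # Note: like the patched code, this iterates over correct_by_label only,
    # so a label with zero correct predictions (ORG here) never shows up.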
wikidata_processor.py (inferred: the module imported as wp, providing the frequency counts):

@@ -1,7 +1,6 @@
 # coding: utf-8
 from __future__ import unicode_literals
 
-import re
 import bz2
 import json
 import datetime
@@ -14,7 +13,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
     """ Read the JSON wiki data and parse out the entities. Takes about 7h30 to parse 55M lines. """
 
     lang = 'en'
-    prop_filter = {'P31': {'Q5', 'Q15632617'}}     # currently defined as OR: one property suffices to be selected
+    # prop_filter = {'P31': {'Q5', 'Q15632617'}}     # currently defined as OR: one property suffices to be selected
     site_filter = 'enwiki'
 
     title_to_id = dict()
@@ -41,18 +40,19 @@ def read_wikidata_entities_json(limit=None, to_print=False):
                 entry_type = obj["type"]
 
                 if entry_type == "item":
-                    # filtering records on their properties
-                    keep = False
+                    # filtering records on their properties (currently disabled to get ALL data)
+                    # keep = False
+                    keep = True
 
                     claims = obj["claims"]
-                    for prop, value_set in prop_filter.items():
-                        claim_property = claims.get(prop, None)
-                        if claim_property:
-                            for cp in claim_property:
-                                cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
-                                cp_rank = cp['rank']
-                                if cp_rank != "deprecated" and cp_id in value_set:
-                                    keep = True
+                    # for prop, value_set in prop_filter.items():
+                        # claim_property = claims.get(prop, None)
+                        # if claim_property:
+                            # for cp in claim_property:
+                                # cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
+                                # cp_rank = cp['rank']
+                                # if cp_rank != "deprecated" and cp_id in value_set:
+                                    # keep = True
 
                     if keep:
                         unique_id = obj["id"]
@@ -70,6 +70,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
                                     if to_print:
                                         print("prop:", prop, cp_values)
 
+                        found_link = False
                         if parse_sitelinks:
                             site_value = obj["sitelinks"].get(site_filter, None)
                             if site_value:
@@ -77,6 +78,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
                                 if to_print:
                                     print(site_filter, ":", site)
                                 title_to_id[site] = unique_id
+                                found_link = True
 
                         if parse_labels:
                             labels = obj["labels"]
@@ -86,7 +88,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
                                     if to_print:
                                         print("label (" + lang + "):", lang_label["value"])
 
-                        if parse_descriptions:
+                        if found_link and parse_descriptions:
                             descriptions = obj["descriptions"]
                             if descriptions:
                                 lang_descr = descriptions.get(lang, None)
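The found_link flag added above gates description parsing on the presence of an enwiki sitelink: an item without a Wikipedia page never enters title_to_id, so storing its description would be wasted work. A reduced, runnable sketch of that control flow (obj is a toy stand-in for one parsed Wikidata record, not real dump data):

    obj = {
        "id": "Q42",
        "sitelinks": {"enwiki": {"title": "Douglas Adams"}},
        "descriptions": {"en": {"value": "English writer and humorist"}},
    }
    site_filter = "enwiki"
    title_to_id, id_to_descr = {}, {}

    found_link = False
    site_value = obj["sitelinks"].get(site_filter)
    if site_value:
        title_to_id[site_value["title"]] = obj["id"]
        found_link = True

    if found_link:  # descriptions only matter for items with an enwiki page
        lang_descr = obj["descriptions"].get("en")
        if lang_descr:
            id_to_descr[obj["id"]] = lang_descr["value"]

    print(title_to_id, id_to_descr)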
@@ -175,7 +175,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
         print("Total count:", total_count)
 
 
-def get_entity_frequencies(count_input, entities):
+def get_all_frequencies(count_input):
     entity_to_count = dict()
     with open(count_input, 'r', encoding='utf8') as csvfile:
         csvreader = csv.reader(csvfile, delimiter='|')
@@ -184,4 +184,5 @@ def get_entity_frequencies(count_input, entities):
         for row in csvreader:
             entity_to_count[row[0]] = int(row[1])
 
-    return [entity_to_count.get(e, 0) for e in entities]
+    return entity_to_count
+
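get_all_frequencies now returns the whole title-to-count mapping instead of a list aligned to a caller-supplied entities sequence; the caller in create_kb does its own .get(title, 0) lookups, so a single pass over the counts file serves any number of queries. A toy sketch of the new calling convention (the pipe-delimited layout matches the csv.reader above; the in-memory "file" is illustrative):

    import csv
    import io

    # In-memory stand-in for the counts file.
    counts_file = io.StringIO("Douglas Adams|5\nBrazil|7\n")

    entity_to_count = {}
    for row in csv.reader(counts_file, delimiter="|"):
        entity_to_count[row[0]] = int(row[1])

    # Callers look up lazily, defaulting to 0 for unknown titles.
    for title in ("Douglas Adams", "Obscure Page"):
        print(title, entity_to_count.get(title, 0))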