mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-29 18:23:06 +03:00
performance per entity type
This commit is contained in:
parent
b312f2d0e7
commit
81731907ba
|
@ -15,10 +15,10 @@ INPUT_DIM = 300 # dimension of pre-trained vectors
|
||||||
DESC_WIDTH = 64
|
DESC_WIDTH = 64
|
||||||
|
|
||||||
|
|
||||||
def create_kb(nlp, max_entities_per_alias, min_occ,
|
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
||||||
entity_def_output, entity_descr_output,
|
entity_def_output, entity_descr_output,
|
||||||
count_input, prior_prob_input, to_print=False):
|
count_input, prior_prob_input, to_print=False):
|
||||||
""" Create the knowledge base from Wikidata entries """
|
# Create the knowledge base from Wikidata entries
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
|
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
|
||||||
|
|
||||||
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
|
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
|
||||||
|
@ -37,21 +37,26 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
|
||||||
title_to_id = _get_entity_to_id(entity_def_output)
|
title_to_id = _get_entity_to_id(entity_def_output)
|
||||||
id_to_descr = _get_id_to_description(entity_descr_output)
|
id_to_descr = _get_id_to_description(entity_descr_output)
|
||||||
|
|
||||||
title_list = list(title_to_id.keys())
|
|
||||||
|
|
||||||
# TODO: remove this filter (just for quicker testing of code)
|
|
||||||
# title_list = title_list[0:342]
|
|
||||||
# title_to_id = {t: title_to_id[t] for t in title_list}
|
|
||||||
|
|
||||||
entity_list = [title_to_id[x] for x in title_list]
|
|
||||||
|
|
||||||
# Currently keeping entities from the KB where there is no description - putting a default void description
|
|
||||||
description_list = [id_to_descr.get(x, "No description defined") for x in entity_list]
|
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(" * _get_entity_frequencies", datetime.datetime.now())
|
print(" * _get_entity_frequencies", datetime.datetime.now())
|
||||||
print()
|
print()
|
||||||
entity_frequencies = wp.get_entity_frequencies(count_input=count_input, entities=title_list)
|
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
|
||||||
|
|
||||||
|
# filter the entities for in the KB by frequency, because there's just too much data otherwise
|
||||||
|
filtered_title_to_id = dict()
|
||||||
|
entity_list = list()
|
||||||
|
description_list = list()
|
||||||
|
frequency_list = list()
|
||||||
|
for title, entity in title_to_id.items():
|
||||||
|
freq = entity_frequencies.get(title, 0)
|
||||||
|
desc = id_to_descr.get(entity, None)
|
||||||
|
if desc and freq > min_entity_freq:
|
||||||
|
entity_list.append(entity)
|
||||||
|
description_list.append(desc)
|
||||||
|
frequency_list.append(freq)
|
||||||
|
filtered_title_to_id[title] = entity
|
||||||
|
|
||||||
|
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles")
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(" * train entity encoder", datetime.datetime.now())
|
print(" * train entity encoder", datetime.datetime.now())
|
||||||
|
@ -67,12 +72,12 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(" * adding", len(entity_list), "entities", datetime.datetime.now())
|
print(" * adding", len(entity_list), "entities", datetime.datetime.now())
|
||||||
kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=embeddings)
|
kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(" * adding aliases", datetime.datetime.now())
|
print(" * adding aliases", datetime.datetime.now())
|
||||||
print()
|
print()
|
||||||
_add_aliases(kb, title_to_id=title_to_id,
|
_add_aliases(kb, title_to_id=filtered_title_to_id,
|
||||||
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
|
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
|
||||||
prior_prob_input=prior_prob_input)
|
prior_prob_input=prior_prob_input)
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ ENTITY_FILE = "gold_entities.csv"
|
||||||
|
|
||||||
def create_training(entity_def_input, training_output):
|
def create_training(entity_def_input, training_output):
|
||||||
wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
|
wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
|
||||||
_process_wikipedia_texts(wp_to_id, training_output, limit=100000000) # TODO: full dataset 100000000
|
_process_wikipedia_texts(wp_to_id, training_output, limit=100000000)
|
||||||
|
|
||||||
|
|
||||||
def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
|
def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
|
||||||
|
|
|
@ -29,6 +29,7 @@ NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
|
||||||
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
||||||
|
|
||||||
MAX_CANDIDATES = 10
|
MAX_CANDIDATES = 10
|
||||||
|
MIN_ENTITY_FREQ = 200
|
||||||
MIN_PAIR_OCC = 5
|
MIN_PAIR_OCC = 5
|
||||||
DOC_SENT_CUTOFF = 2
|
DOC_SENT_CUTOFF = 2
|
||||||
EPOCHS = 10
|
EPOCHS = 10
|
||||||
|
@ -46,14 +47,14 @@ def run_pipeline():
|
||||||
# one-time methods to create KB and write to file
|
# one-time methods to create KB and write to file
|
||||||
to_create_prior_probs = False
|
to_create_prior_probs = False
|
||||||
to_create_entity_counts = False
|
to_create_entity_counts = False
|
||||||
to_create_kb = False # TODO: entity_defs should also contain entities not in the KB
|
to_create_kb = True
|
||||||
|
|
||||||
# read KB back in from file
|
# read KB back in from file
|
||||||
to_read_kb = False
|
to_read_kb = False
|
||||||
to_test_kb = False
|
to_test_kb = False
|
||||||
|
|
||||||
# create training dataset
|
# create training dataset
|
||||||
create_wp_training = True
|
create_wp_training = False
|
||||||
|
|
||||||
# train the EL pipe
|
# train the EL pipe
|
||||||
train_pipe = False
|
train_pipe = False
|
||||||
|
@ -84,13 +85,14 @@ def run_pipeline():
|
||||||
if to_create_kb:
|
if to_create_kb:
|
||||||
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
||||||
kb_1 = kb_creator.create_kb(nlp_1,
|
kb_1 = kb_creator.create_kb(nlp_1,
|
||||||
max_entities_per_alias=MAX_CANDIDATES,
|
max_entities_per_alias=MAX_CANDIDATES,
|
||||||
min_occ=MIN_PAIR_OCC,
|
min_entity_freq=MIN_ENTITY_FREQ,
|
||||||
entity_def_output=ENTITY_DEFS,
|
min_occ=MIN_PAIR_OCC,
|
||||||
entity_descr_output=ENTITY_DESCR,
|
entity_def_output=ENTITY_DEFS,
|
||||||
count_input=ENTITY_COUNTS,
|
entity_descr_output=ENTITY_DESCR,
|
||||||
prior_prob_input=PRIOR_PROB,
|
count_input=ENTITY_COUNTS,
|
||||||
to_print=False)
|
prior_prob_input=PRIOR_PROB,
|
||||||
|
to_print=False)
|
||||||
print("kb entities:", kb_1.get_size_entities())
|
print("kb entities:", kb_1.get_size_entities())
|
||||||
print("kb aliases:", kb_1.get_size_aliases())
|
print("kb aliases:", kb_1.get_size_aliases())
|
||||||
print()
|
print()
|
||||||
|
@ -112,7 +114,7 @@ def run_pipeline():
|
||||||
|
|
||||||
# test KB
|
# test KB
|
||||||
if to_test_kb:
|
if to_test_kb:
|
||||||
run_el.run_kb_toy_example(kb=kb_2)
|
test_kb(kb_2)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# STEP 5: create a training dataset from WP
|
# STEP 5: create a training dataset from WP
|
||||||
|
@ -121,10 +123,18 @@ def run_pipeline():
|
||||||
training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
||||||
|
|
||||||
# STEP 6: create the entity linking pipe
|
# STEP 6: create the entity linking pipe
|
||||||
|
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
|
||||||
|
el_pipe.set_kb(kb_2)
|
||||||
|
nlp_2.add_pipe(el_pipe, last=True)
|
||||||
|
|
||||||
|
other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
|
||||||
|
with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking
|
||||||
|
nlp_2.begin_training()
|
||||||
|
|
||||||
if train_pipe:
|
if train_pipe:
|
||||||
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||||
train_limit = 50
|
train_limit = 10
|
||||||
dev_limit = 10
|
dev_limit = 2
|
||||||
print("Training on", train_limit, "articles")
|
print("Training on", train_limit, "articles")
|
||||||
print("Dev testing on", dev_limit, "articles")
|
print("Dev testing on", dev_limit, "articles")
|
||||||
print()
|
print()
|
||||||
|
@ -141,14 +151,6 @@ def run_pipeline():
|
||||||
limit=dev_limit,
|
limit=dev_limit,
|
||||||
to_print=False)
|
to_print=False)
|
||||||
|
|
||||||
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
|
|
||||||
el_pipe.set_kb(kb_2)
|
|
||||||
nlp_2.add_pipe(el_pipe, last=True)
|
|
||||||
|
|
||||||
other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
|
|
||||||
with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking
|
|
||||||
nlp_2.begin_training()
|
|
||||||
|
|
||||||
for itn in range(EPOCHS):
|
for itn in range(EPOCHS):
|
||||||
random.shuffle(train_data)
|
random.shuffle(train_data)
|
||||||
losses = {}
|
losses = {}
|
||||||
|
@ -180,30 +182,32 @@ def run_pipeline():
|
||||||
# print(" measuring accuracy 1-1")
|
# print(" measuring accuracy 1-1")
|
||||||
el_pipe.context_weight = 1
|
el_pipe.context_weight = 1
|
||||||
el_pipe.prior_weight = 1
|
el_pipe.prior_weight = 1
|
||||||
dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_1_1, dev_acc_1_1_dict = _measure_accuracy(dev_data, el_pipe)
|
||||||
train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
|
print("dev acc combo:", round(dev_acc_1_1, 3), [(x, round(y, 3)) for x, y in dev_acc_1_1_dict.items()])
|
||||||
print("train/dev acc combo:", round(train_acc_1_1, 2), round(dev_acc_1_1, 2))
|
train_acc_1_1, train_acc_1_1_dict = _measure_accuracy(train_data, el_pipe)
|
||||||
|
print("train acc combo:", round(train_acc_1_1, 3), [(x, round(y, 3)) for x, y in train_acc_1_1_dict.items()])
|
||||||
|
|
||||||
# baseline using only prior probabilities
|
# baseline using only prior probabilities
|
||||||
el_pipe.context_weight = 0
|
el_pipe.context_weight = 0
|
||||||
el_pipe.prior_weight = 1
|
el_pipe.prior_weight = 1
|
||||||
dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_0_1, dev_acc_0_1_dict = _measure_accuracy(dev_data, el_pipe)
|
||||||
train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
|
print("dev acc prior:", round(dev_acc_0_1, 3), [(x, round(y, 3)) for x, y in dev_acc_0_1_dict.items()])
|
||||||
print("train/dev acc prior:", round(train_acc_0_1, 2), round(dev_acc_0_1, 2))
|
train_acc_0_1, train_acc_0_1_dict = _measure_accuracy(train_data, el_pipe)
|
||||||
|
print("train acc prior:", round(train_acc_0_1, 3), [(x, round(y, 3)) for x, y in train_acc_0_1_dict.items()])
|
||||||
|
|
||||||
# using only context
|
# using only context
|
||||||
el_pipe.context_weight = 1
|
el_pipe.context_weight = 1
|
||||||
el_pipe.prior_weight = 0
|
el_pipe.prior_weight = 0
|
||||||
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_1_0, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
|
||||||
train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
|
print("dev acc context:", round(dev_acc_1_0, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
|
||||||
print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2))
|
train_acc_1_0, train_acc_1_0_dict = _measure_accuracy(train_data, el_pipe)
|
||||||
|
print("train acc context:", round(train_acc_1_0, 3), [(x, round(y, 3)) for x, y in train_acc_1_0_dict.items()])
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# reset for follow-up tests
|
# reset for follow-up tests
|
||||||
el_pipe.context_weight = 1
|
el_pipe.context_weight = 1
|
||||||
el_pipe.prior_weight = 1
|
el_pipe.prior_weight = 1
|
||||||
|
|
||||||
|
|
||||||
if to_test_pipeline:
|
if to_test_pipeline:
|
||||||
print()
|
print()
|
||||||
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
||||||
|
@ -230,8 +234,8 @@ def run_pipeline():
|
||||||
|
|
||||||
|
|
||||||
def _measure_accuracy(data, el_pipe):
|
def _measure_accuracy(data, el_pipe):
|
||||||
correct = 0
|
correct_by_label = dict()
|
||||||
incorrect = 0
|
incorrect_by_label = dict()
|
||||||
|
|
||||||
docs = [d for d, g in data if len(d) > 0]
|
docs = [d for d, g in data if len(d) > 0]
|
||||||
docs = el_pipe.pipe(docs)
|
docs = el_pipe.pipe(docs)
|
||||||
|
@ -245,31 +249,53 @@ def _measure_accuracy(data, el_pipe):
|
||||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||||
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
if ent.label_ == "PERSON": # TODO: expand to other types
|
ent_label = ent.label_
|
||||||
pred_entity = ent.kb_id_
|
pred_entity = ent.kb_id_
|
||||||
start = ent.start_char
|
start = ent.start_char
|
||||||
end = ent.end_char
|
end = ent.end_char
|
||||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
if gold_entity is not None:
|
if gold_entity is not None:
|
||||||
if gold_entity == pred_entity:
|
if gold_entity == pred_entity:
|
||||||
correct += 1
|
correct = correct_by_label.get(ent_label, 0)
|
||||||
else:
|
correct_by_label[ent_label] = correct + 1
|
||||||
incorrect += 1
|
else:
|
||||||
|
incorrect = incorrect_by_label.get(ent_label, 0)
|
||||||
|
incorrect_by_label[ent_label] = incorrect + 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error assessing accuracy", e)
|
print("Error assessing accuracy", e)
|
||||||
|
|
||||||
if correct == incorrect == 0:
|
acc_by_label = dict()
|
||||||
return 0
|
total_correct = 0
|
||||||
|
total_incorrect = 0
|
||||||
|
for label, correct in correct_by_label.items():
|
||||||
|
incorrect = incorrect_by_label.get(label, 0)
|
||||||
|
total_correct += correct
|
||||||
|
total_incorrect += incorrect
|
||||||
|
if correct == incorrect == 0:
|
||||||
|
acc_by_label[label] = 0
|
||||||
|
else:
|
||||||
|
acc_by_label[label] = correct / (correct + incorrect)
|
||||||
|
acc = 0
|
||||||
|
if not (total_correct == total_incorrect == 0):
|
||||||
|
acc = total_correct / (total_correct + total_incorrect)
|
||||||
|
return acc, acc_by_label
|
||||||
|
|
||||||
acc = correct / (correct + incorrect)
|
|
||||||
return acc
|
def test_kb(kb):
|
||||||
|
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
||||||
|
candidates = kb.get_candidates(mention)
|
||||||
|
|
||||||
|
print("generating candidates for " + mention + " :")
|
||||||
|
for c in candidates:
|
||||||
|
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
def run_el_toy_example(nlp):
|
def run_el_toy_example(nlp):
|
||||||
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
||||||
"Douglas reminds us to always bring our towel. " \
|
"Douglas reminds us to always bring our towel, even in China or Brazil. " \
|
||||||
"The main character in Doug's novel is the man Arthur Dent, " \
|
"The main character in Doug's novel is the man Arthur Dent, " \
|
||||||
"but Douglas doesn't write about George Washington or Homer Simpson."
|
"but Douglas doesn't write about George Washington or Homer Simpson."
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import re
|
|
||||||
import bz2
|
import bz2
|
||||||
import json
|
import json
|
||||||
import datetime
|
import datetime
|
||||||
|
@ -14,7 +13,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
||||||
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
|
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
|
||||||
|
|
||||||
lang = 'en'
|
lang = 'en'
|
||||||
prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
||||||
site_filter = 'enwiki'
|
site_filter = 'enwiki'
|
||||||
|
|
||||||
title_to_id = dict()
|
title_to_id = dict()
|
||||||
|
@ -41,18 +40,19 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
||||||
entry_type = obj["type"]
|
entry_type = obj["type"]
|
||||||
|
|
||||||
if entry_type == "item":
|
if entry_type == "item":
|
||||||
# filtering records on their properties
|
# filtering records on their properties (currently disabled to get ALL data)
|
||||||
keep = False
|
# keep = False
|
||||||
|
keep = True
|
||||||
|
|
||||||
claims = obj["claims"]
|
claims = obj["claims"]
|
||||||
for prop, value_set in prop_filter.items():
|
# for prop, value_set in prop_filter.items():
|
||||||
claim_property = claims.get(prop, None)
|
# claim_property = claims.get(prop, None)
|
||||||
if claim_property:
|
# if claim_property:
|
||||||
for cp in claim_property:
|
# for cp in claim_property:
|
||||||
cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
|
# cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
|
||||||
cp_rank = cp['rank']
|
# cp_rank = cp['rank']
|
||||||
if cp_rank != "deprecated" and cp_id in value_set:
|
# if cp_rank != "deprecated" and cp_id in value_set:
|
||||||
keep = True
|
# keep = True
|
||||||
|
|
||||||
if keep:
|
if keep:
|
||||||
unique_id = obj["id"]
|
unique_id = obj["id"]
|
||||||
|
@ -70,6 +70,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
||||||
if to_print:
|
if to_print:
|
||||||
print("prop:", prop, cp_values)
|
print("prop:", prop, cp_values)
|
||||||
|
|
||||||
|
found_link = False
|
||||||
if parse_sitelinks:
|
if parse_sitelinks:
|
||||||
site_value = obj["sitelinks"].get(site_filter, None)
|
site_value = obj["sitelinks"].get(site_filter, None)
|
||||||
if site_value:
|
if site_value:
|
||||||
|
@ -77,6 +78,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
||||||
if to_print:
|
if to_print:
|
||||||
print(site_filter, ":", site)
|
print(site_filter, ":", site)
|
||||||
title_to_id[site] = unique_id
|
title_to_id[site] = unique_id
|
||||||
|
found_link = True
|
||||||
|
|
||||||
if parse_labels:
|
if parse_labels:
|
||||||
labels = obj["labels"]
|
labels = obj["labels"]
|
||||||
|
@ -86,7 +88,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
||||||
if to_print:
|
if to_print:
|
||||||
print("label (" + lang + "):", lang_label["value"])
|
print("label (" + lang + "):", lang_label["value"])
|
||||||
|
|
||||||
if parse_descriptions:
|
if found_link and parse_descriptions:
|
||||||
descriptions = obj["descriptions"]
|
descriptions = obj["descriptions"]
|
||||||
if descriptions:
|
if descriptions:
|
||||||
lang_descr = descriptions.get(lang, None)
|
lang_descr = descriptions.get(lang, None)
|
||||||
|
|
|
@ -175,7 +175,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
||||||
print("Total count:", total_count)
|
print("Total count:", total_count)
|
||||||
|
|
||||||
|
|
||||||
def get_entity_frequencies(count_input, entities):
|
def get_all_frequencies(count_input):
|
||||||
entity_to_count = dict()
|
entity_to_count = dict()
|
||||||
with open(count_input, 'r', encoding='utf8') as csvfile:
|
with open(count_input, 'r', encoding='utf8') as csvfile:
|
||||||
csvreader = csv.reader(csvfile, delimiter='|')
|
csvreader = csv.reader(csvfile, delimiter='|')
|
||||||
|
@ -184,4 +184,5 @@ def get_entity_frequencies(count_input, entities):
|
||||||
for row in csvreader:
|
for row in csvreader:
|
||||||
entity_to_count[row[0]] = int(row[1])
|
entity_to_count[row[0]] = int(row[1])
|
||||||
|
|
||||||
return [entity_to_count.get(e, 0) for e in entities]
|
return entity_to_count
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user