fixes in kb and gold

This commit is contained in:
svlandeg 2019-07-17 17:18:26 +02:00
parent 4086c6ff60
commit d833d4c358
5 changed files with 54 additions and 17 deletions

View File

@ -401,15 +401,13 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
gold_start = int(start) - found_ent.sent.start_char gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char gold_end = int(end) - found_ent.sent.start_char
# add both positive and negative examples (in random order just to be sure) # add both pos and neg examples (in random order)
if kb: if kb:
gold_entities = {} gold_entities = {}
candidate_ids = [ candidates = kb.get_candidates(alias)
c.entity_ for c in kb.get_candidates(alias) candidate_ids = [c.entity_ for c in candidates]
] # add positive example in case the KB doesn't have it
candidate_ids.append( candidate_ids.append(wd_id)
wd_id
) # in case the KB doesn't have it
random.shuffle(candidate_ids) random.shuffle(candidate_ids)
for kb_id in candidate_ids: for kb_id in candidate_ids:
entry = (gold_start, gold_end, kb_id) entry = (gold_start, gold_end, kb_id)
@ -418,7 +416,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
else: else:
gold_entities[entry] = 1.0 gold_entities[entry] = 1.0
else: else:
gold_entities = {} entry = (gold_start, gold_end, wd_id)
gold_entities = {entry: 1.0}
gold = GoldParse(doc=sent, links=gold_entities) gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold)) data.append((sent, gold))

View File

@ -31,7 +31,7 @@ cdef class GoldParse:
cdef public list ents cdef public list ents
cdef public dict brackets cdef public dict brackets
cdef public object cats cdef public object cats
cdef public list links cdef public dict links
cdef readonly list cand_to_gold cdef readonly list cand_to_gold
cdef readonly list gold_to_cand cdef readonly list gold_to_cand

View File

@ -450,8 +450,10 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels dictionary are treated as missing - the gradient for those labels
will be zero. will be zero.
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples, links (dict): A dict with `(start_char, end_char, kb_id)` keys,
representing the external ID of an entity in a knowledge base. representing the external ID of an entity in a knowledge base,
and the values being either 1.0 or 0.0, indicating positive and
negative examples, respectively.
RETURNS (GoldParse): The newly constructed object. RETURNS (GoldParse): The newly constructed object.
""" """
if words is None: if words is None:

View File

@ -191,7 +191,7 @@ cdef class KnowledgeBase:
def get_candidates(self, unicode alias): def get_candidates(self, unicode alias):
cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
alias_index = <int64_t>self._alias_index.get(alias_hash) alias_index = <int64_t>self._alias_index.get(alias_hash) # TODO: check for error? unit test !
alias_entry = self._aliases_table[alias_index] alias_entry = self._aliases_table[alias_index]
return [Candidate(kb=self, return [Candidate(kb=self,
@ -199,12 +199,12 @@ cdef class KnowledgeBase:
entity_freq=self._entries[entry_index].prob, entity_freq=self._entries[entry_index].prob,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index], entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
alias_hash=alias_hash, alias_hash=alias_hash,
prior_prob=prob) prior_prob=prior_prob)
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs) for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0] if entry_index != 0]
def get_vector(self, unicode entity): def get_vector(self, unicode entity):
cdef hash_t entity_hash = self.vocab.strings.add(entity) cdef hash_t entity_hash = self.vocab.strings[entity]
# Return an empty list if this entity is unknown in this KB # Return an empty list if this entity is unknown in this KB
if entity_hash not in self._entry_index: if entity_hash not in self._entry_index:
@ -213,6 +213,27 @@ cdef class KnowledgeBase:
return self._vectors_table[self._entries[entry_index].vector_index] return self._vectors_table[self._entries[entry_index].vector_index]
def get_prior_prob(self, unicode entity, unicode alias):
    """Return the prior probability of a given alias being linked to a given
    entity, or 0.0 when this combination is not known in the knowledge base.

    entity (unicode): The entity identifier to look up.
    alias (unicode): The alias to look up.
    RETURNS (float): The probability stored for this entity in the alias
        entry, or 0.0 if the entity or alias is unknown, or if the alias
        entry does not list the entity.
    """
    cdef hash_t alias_hash = self.vocab.strings[alias]
    cdef hash_t entity_hash = self.vocab.strings[entity]

    # Unknown entity or unknown alias: report 0.0 instead of raising.
    # TODO: should this be an error?
    if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
        return 0.0

    alias_index = <int64_t>self._alias_index.get(alias_hash)
    alias_entry = self._aliases_table[alias_index]

    # Walk the entities registered for this alias, paired with their prior
    # probabilities, and return the probability of the requested entity.
    for (candidate_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
        if self._entries[candidate_index].entity_hash == entity_hash:
            return prior_prob

    # The alias exists, but it is not linked to this entity.
    return 0.0
def dump(self, loc): def dump(self, loc):
cdef Writer writer = Writer(loc) cdef Writer writer = Writer(loc)
writer.write_header(self.get_size_entities(), self.entity_vector_length) writer.write_header(self.get_size_entities(), self.entity_vector_length)

View File

@ -13,6 +13,11 @@ def nlp():
return English() return English()
def assert_almost_equal(a, b, delta=0.0001):
    """Assert that `b` lies within `delta` of `a` (bounds inclusive).

    a (float): The reference value.
    b (float): The value that should be close to `a`.
    delta (float): The maximum allowed absolute difference. Defaults to
        0.0001, the tolerance that was previously hard-coded.
    """
    lower = a - delta
    upper = a + delta
    assert lower <= b <= upper, "{} is not within {} of {}".format(b, delta, a)
def test_kb_valid_entities(nlp): def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases""" """Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
@ -35,6 +40,10 @@ def test_kb_valid_entities(nlp):
assert mykb.get_vector("Q2") == [2, 1, 0] assert mykb.get_vector("Q2") == [2, 1, 0]
assert mykb.get_vector("Q3") == [-1, -6, 5] assert mykb.get_vector("Q3") == [-1, -6, 5]
# test retrieval of prior probabilities
assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
def test_kb_invalid_entities(nlp): def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity""" """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
@ -99,12 +108,12 @@ def test_candidate_generation(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities # adding entities
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1]) mykb.add_entity(entity="Q1", prob=0.7, entity_vector=[1])
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2]) mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3]) mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
# adding aliases # adding aliases
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2]) mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates # test the size of the relevant candidates
@ -112,6 +121,12 @@ def test_candidate_generation(nlp):
assert len(mykb.get_candidates("adam")) == 1 assert len(mykb.get_candidates("adam")) == 1
assert len(mykb.get_candidates("shrubbery")) == 0 assert len(mykb.get_candidates("shrubbery")) == 0
# test the content of the candidates
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
assert mykb.get_candidates("adam")[0].alias_ == "adam"
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
def test_preserving_links_asdoc(nlp): def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links""" """Test that Span.as_doc preserves the existing entity links"""