fixes in kb and gold

This commit is contained in:
svlandeg 2019-07-17 17:18:26 +02:00
parent 4086c6ff60
commit d833d4c358
5 changed files with 54 additions and 17 deletions

View File

@ -401,15 +401,13 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
# add both positive and negative examples (in random order just to be sure)
# add both pos and neg examples (in random order)
if kb:
gold_entities = {}
candidate_ids = [
c.entity_ for c in kb.get_candidates(alias)
]
candidate_ids.append(
wd_id
) # in case the KB doesn't have it
candidates = kb.get_candidates(alias)
candidate_ids = [c.entity_ for c in candidates]
# add positive example in case the KB doesn't have it
candidate_ids.append(wd_id)
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
entry = (gold_start, gold_end, kb_id)
@ -418,7 +416,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
else:
gold_entities[entry] = 1.0
else:
gold_entities = {}
entry = (gold_start, gold_end, wd_id)
gold_entities = {entry: 1.0}
gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))

View File

@ -31,7 +31,7 @@ cdef class GoldParse:
cdef public list ents
cdef public dict brackets
cdef public object cats
cdef public list links
cdef public dict links
cdef readonly list cand_to_gold
cdef readonly list gold_to_cand

View File

@ -450,8 +450,10 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels
will be zero.
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
representing the external ID of an entity in a knowledge base.
links (dict): A dict with `(start_char, end_char, kb_id)` keys,
representing the external ID of an entity in a knowledge base,
and the values being either 1.0 or 0.0, indicating positive and
negative examples, respectively.
RETURNS (GoldParse): The newly constructed object.
"""
if words is None:

View File

@ -191,7 +191,7 @@ cdef class KnowledgeBase:
def get_candidates(self, unicode alias):
cdef hash_t alias_hash = self.vocab.strings[alias]
alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_index = <int64_t>self._alias_index.get(alias_hash) # TODO: check for error? unit test !
alias_entry = self._aliases_table[alias_index]
return [Candidate(kb=self,
@ -199,12 +199,12 @@ cdef class KnowledgeBase:
entity_freq=self._entries[entry_index].prob,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
alias_hash=alias_hash,
prior_prob=prob)
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
prior_prob=prior_prob)
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0]
def get_vector(self, unicode entity):
cdef hash_t entity_hash = self.vocab.strings.add(entity)
cdef hash_t entity_hash = self.vocab.strings[entity]
# Return an empty list if this entity is unknown in this KB
if entity_hash not in self._entry_index:
@ -213,6 +213,27 @@ cdef class KnowledgeBase:
return self._vectors_table[self._entries[entry_index].vector_index]
def get_prior_prob(self, unicode entity, unicode alias):
    """Return the prior probability of a given alias being linked to a
    given entity, or return 0.0 when this combination is not known in the
    knowledge base.

    entity (unicode): The ID of the entity to look up.
    alias (unicode): The alias to look up.
    RETURNS (float): The prior probability stored for this alias/entity
        pair, or 0.0 if the entity is unknown, the alias is unknown, or
        the alias is not linked to the entity.
    """
    cdef hash_t alias_hash = self.vocab.strings[alias]
    cdef hash_t entity_hash = self.vocab.strings[entity]

    # An unknown entity or alias is not treated as an error: per the
    # contract above, any unknown combination simply yields 0.0.
    if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
        return 0.0

    alias_index = <int64_t>self._alias_index.get(alias_hash)
    alias_entry = self._aliases_table[alias_index]

    # Scan the entries linked to this alias and match on the entity hash.
    # The entity's own index is not needed for this, so it is not looked up.
    for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
        if self._entries[entry_index].entity_hash == entity_hash:
            return prior_prob

    # Both alias and entity exist, but they are not linked to each other.
    return 0.0
def dump(self, loc):
cdef Writer writer = Writer(loc)
writer.write_header(self.get_size_entities(), self.entity_vector_length)

View File

@ -13,6 +13,11 @@ def nlp():
return English()
def assert_almost_equal(a, b):
    """Assert that b lies within an absolute tolerance of 0.0001 around a."""
    tolerance = 0.0001
    lower_bound = a - tolerance
    upper_bound = a + tolerance
    within_bounds = lower_bound <= b and b <= upper_bound
    assert within_bounds
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
@ -35,6 +40,10 @@ def test_kb_valid_entities(nlp):
assert mykb.get_vector("Q2") == [2, 1, 0]
assert mykb.get_vector("Q3") == [-1, -6, 5]
# test retrieval of prior probabilities
assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
@ -99,12 +108,12 @@ def test_candidate_generation(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
mykb.add_entity(entity="Q1", prob=0.7, entity_vector=[1])
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
@ -112,6 +121,12 @@ def test_candidate_generation(nlp):
assert len(mykb.get_candidates("adam")) == 1
assert len(mykb.get_candidates("shrubbery")) == 0
# test the content of the candidates
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
assert mykb.get_candidates("adam")[0].alias_ == "adam"
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links"""