fixes in kb and gold
This commit is contained in:
parent 4086c6ff60
commit d833d4c358
@@ -401,15 +401,13 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
             gold_start = int(start) - found_ent.sent.start_char
             gold_end = int(end) - found_ent.sent.start_char
 
-            # add both positive and negative examples (in random order just to be sure)
+            # add both pos and neg examples (in random order)
             if kb:
                 gold_entities = {}
-                candidate_ids = [
-                    c.entity_ for c in kb.get_candidates(alias)
-                ]
-                candidate_ids.append(
-                    wd_id
-                )  # in case the KB doesn't have it
+                candidates = kb.get_candidates(alias)
+                candidate_ids = [c.entity_ for c in candidates]
+                # add positive example in case the KB doesn't have it
+                candidate_ids.append(wd_id)
                 random.shuffle(candidate_ids)
                 for kb_id in candidate_ids:
                     entry = (gold_start, gold_end, kb_id)
@@ -418,7 +416,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                     else:
                         gold_entities[entry] = 1.0
             else:
-                gold_entities = {}
+                entry = (gold_start, gold_end, wd_id)
+                gold_entities = {entry: 1.0}
 
             gold = GoldParse(doc=sent, links=gold_entities)
             data.append((sent, gold))
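For context, the reworked branch above amounts to: look up the KB candidates for the mention and treat them as negative examples, make sure the gold Wikidata ID is present as the positive one, and collect everything in a dict keyed by character offsets. A minimal sketch of that pattern (not the example script itself), assuming a loaded `kb` plus the surrounding variables `alias`, `wd_id`, `gold_start` and `gold_end`:

    import random

    candidate_ids = [c.entity_ for c in kb.get_candidates(alias)]
    if wd_id not in candidate_ids:
        candidate_ids.append(wd_id)  # make sure the positive example is included
    random.shuffle(candidate_ids)  # mix positive and negative examples
    gold_entities = {
        (gold_start, gold_end, kb_id): 1.0 if kb_id == wd_id else 0.0
        for kb_id in candidate_ids
    }

This is the dict that gets handed to `GoldParse(doc=sent, links=gold_entities)` in the hunk above.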
@@ -31,7 +31,7 @@ cdef class GoldParse:
     cdef public list ents
     cdef public dict brackets
     cdef public object cats
-    cdef public list links
+    cdef public dict links
 
     cdef readonly list cand_to_gold
     cdef readonly list gold_to_cand
@@ -450,8 +450,10 @@ cdef class GoldParse:
             examples of a label to have the value 0.0. Labels not in the
             dictionary are treated as missing - the gradient for those labels
             will be zero.
-        links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
-            representing the external ID of an entity in a knowledge base.
+        links (dict): A dict with `(start_char, end_char, kb_id)` keys,
+            representing the external ID of an entity in a knowledge base,
+            and the values being either 1.0 or 0.0, indicating positive and
+            negative examples, respectively.
         RETURNS (GoldParse): The newly constructed object.
         """
         if words is None:
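With `links` now a dict, a positive and a negative link for the same span can be encoded side by side. A small usage sketch (the sentence and QIDs are illustrative, and `nlp` is assumed to be a loaded English pipeline):

    from spacy.gold import GoldParse

    doc = nlp("Douglas Adams wrote a book.")
    links = {
        (0, 13, "Q42"): 1.0,   # correct KB identifier: positive example
        (0, 13, "Q100"): 0.0,  # wrong candidate: negative example
    }
    gold = GoldParse(doc=doc, links=links)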
spacy/kb.pyx
@@ -191,7 +191,7 @@ cdef class KnowledgeBase:
 
     def get_candidates(self, unicode alias):
         cdef hash_t alias_hash = self.vocab.strings[alias]
-        alias_index = <int64_t>self._alias_index.get(alias_hash)
+        alias_index = <int64_t>self._alias_index.get(alias_hash)  # TODO: check for error? unit test !
         alias_entry = self._aliases_table[alias_index]
 
         return [Candidate(kb=self,
@@ -199,12 +199,12 @@ cdef class KnowledgeBase:
                           entity_freq=self._entries[entry_index].prob,
                           entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
                           alias_hash=alias_hash,
-                          prior_prob=prob)
-                for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
+                          prior_prob=prior_prob)
+                for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
                 if entry_index != 0]
 
     def get_vector(self, unicode entity):
-        cdef hash_t entity_hash = self.vocab.strings.add(entity)
+        cdef hash_t entity_hash = self.vocab.strings[entity]
 
         # Return an empty list if this entity is unknown in this KB
         if entity_hash not in self._entry_index:
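After the rename, each returned `Candidate` exposes its prior probability under one consistent name. A quick sketch of inspecting the candidates for an alias, assuming a `mykb` built with `add_entity`/`add_alias` as in the tests further down:

    for c in mykb.get_candidates("douglas"):
        print(c.entity_, c.alias_, c.entity_freq, c.prior_prob)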
@@ -213,6 +213,27 @@ cdef class KnowledgeBase:
 
         return self._vectors_table[self._entries[entry_index].vector_index]
 
+    def get_prior_prob(self, unicode entity, unicode alias):
+        """ Return the prior probability of a given alias being linked to a given entity,
+        or return 0.0 when this combination is not known in the knowledge base"""
+        cdef hash_t alias_hash = self.vocab.strings[alias]
+        cdef hash_t entity_hash = self.vocab.strings[entity]
+
+        # TODO: error ?
+        if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
+            return 0.0
+
+        alias_index = <int64_t>self._alias_index.get(alias_hash)
+        entry_index = self._entry_index[entity_hash]
+
+        alias_entry = self._aliases_table[alias_index]
+        for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
+            if self._entries[entry_index].entity_hash == entity_hash:
+                return prior_prob
+
+        return 0.0
+
+
     def dump(self, loc):
         cdef Writer writer = Writer(loc)
         writer.write_header(self.get_size_entities(), self.entity_vector_length)
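The new `get_prior_prob` helper makes the alias-to-entity priors directly queryable and falls back to 0.0 for unknown pairs instead of raising. For example, using the `mykb` fixture from the tests below (expected values match the new assertions):

    mykb.get_prior_prob(entity="Q2", alias="douglas")    # ~0.8
    mykb.get_prior_prob(entity="Q2", alias="shrubbery")  # 0.0, alias not in the KB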
@@ -13,6 +13,11 @@ def nlp():
     return English()
 
 
+def assert_almost_equal(a, b):
+    delta = 0.0001
+    assert a - delta <= b <= a + delta
+
+
 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
@@ -35,6 +40,10 @@ def test_kb_valid_entities(nlp):
     assert mykb.get_vector("Q2") == [2, 1, 0]
     assert mykb.get_vector("Q3") == [-1, -6, 5]
 
+    # test retrieval of prior probabilities
+    assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
+    assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
+
 
 def test_kb_invalid_entities(nlp):
     """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
@@ -99,12 +108,12 @@ def test_candidate_generation(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q1", prob=0.7, entity_vector=[1])
     mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
     mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
 
     # adding aliases
-    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
 
     # test the size of the relevant candidates
@@ -112,6 +121,12 @@ def test_candidate_generation(nlp):
     assert len(mykb.get_candidates("adam")) == 1
     assert len(mykb.get_candidates("shrubbery")) == 0
 
+    # test the content of the candidates
+    assert mykb.get_candidates("adam")[0].entity_ == "Q2"
+    assert mykb.get_candidates("adam")[0].alias_ == "adam"
+    assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
+    assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
+
 
 def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""